forked from mindspore-Ecosystem/mindspore
commit 65dabb58ef

@ -0,0 +1,245 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [AttGAN Description](#attgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Exporting MindIR](#exporting-mindir)
        - [Inference on Ascend 310](#inference-on-ascend-310)
        - [Results](#results)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [AttGAN on CelebA](#attgan-on-celeba)
        - [Inference Performance](#inference-performance)
            - [AttGAN on CelebA](#attgan-on-celeba-1)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# AttGAN Description

AttGAN refers to "AttGAN: Facial Attribute Editing by Only Changing What You Want". The distinguishing feature of this network is that it can edit a facial attribute without affecting the other attributes of the face.

[Paper](https://arxiv.org/abs/1711.10678): [Zhenliang He](https://github.com/LynnHo/AttGAN-Tensorflow), Wangmeng Zuo, Meina Kan, Shiguang Shan, Xilin Chen, et al. AttGAN: Facial Attribute Editing by Only Changing What You Want[C]// 2017 CVPR. IEEE

# Model Architecture

The network consists of one generator and one discriminator; the generator itself is composed of an encoder and a decoder. The model drops the strict attribute-independent constraint on the latent representation and instead relies on attribute classification to guarantee that attributes are edited correctly. By combining an attribute classification constraint, adversarial learning, and reconstruction learning, it achieves good facial attribute editing results.

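The way the three training signals above combine can be sketched as a weighted sum forming the generator objective. This is an illustrative sketch, not the repository's `src/loss.py`; the weight values `lambda_cls` and `lambda_rec` are hypothetical placeholders, since the real coefficients come from the training configuration.

```python
# Hedged sketch: weighted combination of the adversarial, attribute-
# classification, and reconstruction terms described in the text.
# lambda_cls and lambda_rec are hypothetical illustration values.

def generator_objective(adv_loss, cls_loss, rec_loss,
                        lambda_cls=10.0, lambda_rec=100.0):
    """Weighted sum of the three AttGAN training terms."""
    return adv_loss + lambda_cls * cls_loss + lambda_rec * rec_loss

print(generator_objective(1.0, 0.5, 0.02))
```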
# Dataset

Dataset used: [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)

CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each annotated with 40 attributes. CelebA offers large diversity, large quantity, and rich annotations, including:

- 10,177 identities,
- 202,599 face images, with 5 landmark locations and 40 binary attribute annotations per image.

The dataset can be used as a training and test set for the following computer vision tasks: face attribute recognition, face detection, and face editing and synthesis.

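The attribute annotations ship as a text file (`list_attr_celeba.txt`): the first line holds the image count, the second line the attribute names, and each remaining line an image name followed by one +1/-1 flag per attribute. A minimal parsing sketch, using a three-attribute illustrative sample rather than real CelebA data:

```python
# Hedged sketch of parsing a CelebA-style attribute annotation file.
import io

def parse_attr_file(fh):
    n_images = int(fh.readline())          # line 1: image count
    attr_names = fh.readline().split()     # line 2: attribute names
    rows = {}
    for _ in range(n_images):              # one row per image: name + flags
        parts = fh.readline().split()
        rows[parts[0]] = [int(v) for v in parts[1:]]
    return attr_names, rows

sample = io.StringIO("1\nBald Eyeglasses Male\n000001.jpg -1 1 1\n")
names, rows = parse_attr_file(sample)
print(names, rows["000001.jpg"])
```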
# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

# Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation:

- Running in the Ascend environment

```bash
# run the training example
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
# OR
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt

# run the distributed training example
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt

# run the evaluation example
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
# OR
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name
```

For distributed training, an hccl configuration file in JSON format must be created in advance. The absolute path of this file is passed as the first argument of the distributed training script.

Please follow the instructions at the link below:

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

For the evaluation script, a directory holding the custom images (jpg) and an attribute editing file must be created in advance; see [Script and Sample Code](#script-and-sample-code) for a description of the attribute editing file. The directory and the attribute editing file correspond to the parameters `custom_data` and `custom_attr`, respectively. By default, the training script places checkpoint files under the `output/{experiment_name}/checkpoint` directory; when running the evaluation script, pass the name of the generator checkpoint file as an argument.

[Note] All of the paths above should be set to absolute paths.

# Script Description

## Script and Sample Code

```text
.
└─ cv
  └─ AttGAN
    ├─ ascend310_infer            # Ascend 310 inference directory
    ├─ scripts
      ├─ run_distribute_train.sh  # shell script for distributed training
      ├─ run_single_train.sh      # shell script for single-device training
      ├─ run_eval.sh              # evaluation script
      ├─ run_infer_310.sh         # inference script
    ├─ src
      ├─ __init__.py              # init file
      ├─ block.py                 # basic cells
      ├─ attgan.py                # generator and discriminator networks
      ├─ utils.py                 # helper functions
      ├─ cell.py                  # loss network wrappers
      ├─ data.py                  # data loading
      ├─ helpers.py               # progress bar display
      ├─ loss.py                  # loss computation
    ├─ eval.py                    # test script
    ├─ train.py                   # training script
    ├─ export.py                  # MINDIR export script
    ├─ preprocess.py              # Ascend 310 inference preprocessing script
    ├─ postprocess.py             # Ascend 310 inference postprocessing script
    └─ README_CN.md               # AttGAN description file
```

The scripts can edit 13 attributes: Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young.

## Training Process

### Training

- Running in the Ascend environment

```bash
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --img_size 128 --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
```

After training finishes, an `output` directory is created in the current directory, with a subdirectory named after the `experiment_name` parameter you set. The training parameters are saved in the `setting.txt` file in that subdirectory, and checkpoint files are saved under `output/{experiment_name}/rank0`. If you only need to generate the `setting.txt` file, running `train.py` once is enough: the script generates a `setting.txt` matching the given parameters. This file is later used by the inference scripts.

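The `setting.txt` round trip can be sketched as follows: `train.py` dumps its arguments as JSON, and the evaluation/export scripts reload them into an `argparse.Namespace` via an `object_hook` (mirroring the loading code in this repository's `eval.py`). The key/value pairs below are a small illustrative subset, not the full argument set.

```python
# Hedged sketch of reloading setting.txt into a Namespace, as eval.py does.
import argparse
import io
import json

# Illustrative file contents; a real setting.txt holds every training argument.
saved = io.StringIO('{"img_size": 128, "shortcut_layers": 1, "attrs": ["Bald", "Male"]}')
args = json.load(saved, object_hook=lambda d: argparse.Namespace(**d))
print(args.img_size, args.attrs)
```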
### Distributed Training

- Running in the Ascend environment

```bash
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
```

The above shell script runs distributed training in the background. It creates a `LOG{RANK_ID}` directory under the script directory for each process, and the output of each process is recorded in the `log.txt` file in the corresponding `LOG{RANK_ID}` directory. Checkpoint files are saved under `output/{experiment_name}/rank{RANK_ID}`.

## Evaluation Process

### Evaluation

- Evaluating a custom dataset in the Ascend environment

  The network edits facial attributes: put the images you want to edit in a custom image directory, and edit the attribute editing file according to the attributes you want to change (for the file's format, refer to the CelebA dataset and its attribute file). Then pass the custom image directory and the attribute editing file to the test script as the `custom_data` and `custom_attr` arguments, respectively.

  For evaluation, pass the name of an existing checkpoint file to the test script via the `gen_ckpt_name` argument (the checkpoint stores the encoder and decoder parameters).

```bash
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
```

After the test script finishes, the edited images can be found under `output/{experiment_name}/custom_img` in the current directory.

## Inference Process

### Exporting MindIR

```shell
python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT]
```

`file_format` must be chosen from ["AIR", "MINDIR"].
`experiment_name` is the name of the result folder under the `output` directory; this parameter helps `export.py` locate the saved training parameters.

The script generates the corresponding MINDIR file in the current directory.

### Inference on Ascend 310

Before running inference, the MINDIR model must be exported with the export script. The following command shows how to run attribute editing on images on Ascend 310:

```bash
bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

- `GEN_MINDIR_PATH`: path of the MINDIR file.
- `ATTR_FILE_PATH`: path of the attribute editing file; it should be an absolute path.
- `DATA_PATH`: directory of the dataset to run inference on; images should be in jpg format.
- `NEED_PREPROCESS`: whether the attribute editing file needs preprocessing; choose `y` or `n`. Choosing `y` runs preprocessing (the attribute editing file must be preprocessed the first time inference is run, which can take some time if there are many images).
- `DEVICE_ID`: optional, defaults to 0.

[Note] For the format of the attribute editing file, refer to the `list_attr_celeba.txt` file of the CelebA dataset: the first line is the number of images to run inference on, the second line lists the attributes to edit, and the following lines hold the names of the images to edit and their attribute tags. The number of images in the attribute editing file must equal the number of images in the dataset directory.

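The attribute editing file layout described above can be sketched as follows: image count on line one, the attribute list on line two, then one line per image with its name and attribute tags. The image name and tag values below are hypothetical.

```python
# Hedged sketch of assembling an attribute editing file in the layout
# described in the note: count, attribute names, then per-image rows.
def make_attr_file_text(attrs, images):
    lines = [str(len(images)), " ".join(attrs)]
    for name, tags in images:
        lines.append(name + " " + " ".join(str(t) for t in tags))
    return "\n".join(lines) + "\n"

text = make_attr_file_text(["Bald", "Male"], [("182001.jpg", [-1, 1])])
print(text)
```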
### Results

The inference results are saved in the directory where the script is executed: attribute-edited images under `result_Files/` and inference timing statistics under `time_Result/`. Edited images are saved as `imgName_attrId.jpg`; for example, `182001_1.jpg` is the result of editing the first attribute of the image named 182001. Whether an attribute is actually edited is determined by the contents of the attribute editing file.

# Model Description

## Performance

### Evaluation Performance

#### AttGAN on CelebA

| Parameters          | Ascend 910                  |
| ------------------- | --------------------------- |
| Model version       | AttGAN                      |
| Resource            | Ascend                      |
| Uploaded date       | 06/30/2021 (month/day/year) |
| MindSpore version   | 1.2.0                       |
| Dataset             | CelebA                      |
| Training parameters | batch_size=32, lr=0.0002    |
| Optimizer           | Adam                        |
| Generator output    | image                       |
| Speed               | 5.56 steps/s                |
| Script              | [AttGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/AttGAN) |

### Inference Performance

#### AttGAN on CelebA

| Parameters           | Ascend 910                  |
| -------------------- | --------------------------- |
| Model version        | AttGAN                      |
| Resource             | Ascend                      |
| Uploaded date        | 06/30/2021 (month/day/year) |
| MindSpore version    | 1.2.0                       |
| Dataset              | CelebA                      |
| Inference parameters | batch_size=1                |
| Output               | image                       |

After inference, the attribute-edited versions of the original images are obtained.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"

DIR *OpenDir(std::string_view dirName);
void Denorm(std::vector<mindspore::MSTensor> *outputs);
std::string RealPath(std::string_view path);
std::vector<std::string> GetAllFiles(std::string_view dirName);
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string &imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif  // MINDSPORE_INFERENCE_UTILS_H_

@ -0,0 +1,149 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include <map>     // required by the std::map used for costTime_map below
#include <memory>  // required by std::make_shared/std::shared_ptr

#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"

using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using color_rep_type = std::underlying_type<mindspore::DataType>::type;

DEFINE_string(gen_mindir_path, "", "generator mindir path");
DEFINE_string(dataset_path, "", "dataset path");
DEFINE_string(attr_file_path, "", "attribute file path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_int32(image_height, 128, "image height");
DEFINE_int32(image_width, 128, "image width");

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_gen_mindir_path).empty()) {
    std::cout << "Invalid generator mindir" << std::endl;
    return 1;
  }
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  ascend310->SetBufferOptimizeMode("off_optimize");
  context->MutableDeviceInfo().push_back(ascend310);

  mindspore::Graph gen_graph;
  Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph);

  Model gen_model;
  Status gen_ret = gen_model.Build(GraphCell(gen_graph), context);

  if (gen_ret != kSuccess) {
    std::cout << "ERROR: Generator build failed." << std::endl;
    return 1;
  }

  size_t n_attrs = 13;
  auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs);
  auto all_files = GetAllFiles(FLAGS_dataset_path);
  std::map<double, double> costTime_map;
  double startTimeMs;
  double endTimeMs;
  size_t size = all_files.size();

  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    std::cout << "Start predict input files:" << all_files[i] << std::endl;

    auto img = std::make_shared<MSTensor>();
    std::shared_ptr<TensorTransform> decode(new Decode());
    std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
    auto resizeShape = {FLAGS_image_height, FLAGS_image_width};
    std::shared_ptr<TensorTransform> resize(new Resize(resizeShape));
    std::shared_ptr<TensorTransform> normalize(
        new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
    Execute composeDecode({decode, resize, normalize, hwc2chw});
    auto image = ReadFileToTensor(all_files[i]);
    composeDecode(image, img.get());

    for (size_t k = 0; k < n_attrs; ++k) {
      gettimeofday(&start, nullptr);
      std::vector<MSTensor> inputs;
      std::vector<MSTensor> outputs;
      size_t index = i * n_attrs + k;
      std::cout << static_cast<color_rep_type>(img->DataType()) << std::endl;
      inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize());
      inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(),
                          all_cfg[index].Data().get(), all_cfg[index].DataSize());
      Status gen_model_ret = gen_model.Predict(inputs, &outputs);
      if (gen_model_ret != kSuccess) {
        std::cout << "Generator inference " << all_files[i] << " failed." << std::endl;
        return 1;
      }
      Denorm(&outputs);
      int pos = all_files[i].find('.');
      std::string fileName = all_files[i].substr(0, pos);
      WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs);
      gettimeofday(&end, nullptr);
      startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
      endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
      costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    }
  }
  double average = 0.0;
  int inferCount = 0;

  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  average = average / inferCount;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}

@ -0,0 +1,201 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "inc/utils.h"

#include <fstream>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <sstream>
#include <cstdlib>  // required by malloc/free, atof, exit and realpath
#include <climits>  // required by PATH_MAX
#include <memory>   // required by std::shared_ptr

using mindspore::MSTensor;
using mindspore::DataType;

std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}

int WriteResult(const std::string &imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    std::cout << fileName << std::endl;
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE *outputFile = fopen(outFileName.c_str(), "wb");
    fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}

mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "Pointer file is nullptr" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
                             {static_cast<int64_t>(size)}, nullptr, size);
  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();
  return buffer;
}

std::vector<std::string> split(std::string inputs) {
  std::vector<std::string> line;
  std::stringstream stream(inputs);
  std::string result;
  while (stream >> result) {
    line.push_back(result);
  }
  return line;
}

std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr) {
  std::vector<mindspore::MSTensor> res;
  if (file.empty()) {
    std::cout << "Pointer file is nullptr." << std::endl;
    exit(1);
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist." << std::endl;
    exit(1);
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed." << std::endl;
    exit(1);
  }

  std::string n_images;
  std::string n_attrs;
  getline(ifs, n_images);
  getline(ifs, n_attrs);

  auto n_images_ = std::stoi(n_images);
  auto n_attrs_ = std::stoi(n_attrs);
  *n_ptr = n_attrs_;
  std::cout << "Image number is " << n_images << std::endl;
  std::cout << "Attribute number is " << n_attrs << std::endl;

  auto all_lines = n_images_ * n_attrs_;
  for (auto i = 0; i < all_lines; i++) {
    std::string val;
    getline(ifs, val);
    std::vector<std::string> val_split = split(val);
    void *data = malloc(sizeof(float) * n_attrs_);
    float *elements = reinterpret_cast<float *>(data);
    for (auto j = 0; j < n_attrs_; j++) elements[j] = atof(val_split[j].c_str());
    auto size = sizeof(float) * n_attrs_;
    mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32,
                               {static_cast<int64_t>(size)}, nullptr, size);
    memcpy(buffer.MutableData(), elements, size);
    free(data);  // the tensor holds its own copy, so the scratch buffer can be released
    res.emplace_back(buffer);
  }
  ifs.close();
  return res;
}

DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  lstat(realPath.c_str(), &s);
  if (!S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}

std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  realPathRet = realpath(path.data(), realPathMem);

  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}

void Denorm(std::vector<MSTensor> *outputs) {
  for (size_t i = 0; i < outputs->size(); ++i) {
    size_t outputSize = (*outputs)[i].DataSize();
    float *netOutput = reinterpret_cast<float *>((*outputs)[i].MutableData());
    size_t outputLen = outputSize / sizeof(float);

    for (size_t j = 0; j < outputLen; ++j) {
      netOutput[j] = (netOutput[j] + 1) / 2 * 255;
      netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j];
      netOutput[j] = (netOutput[j] > 255) ? 255 : netOutput[j];
    }
  }
}

@ -0,0 +1,147 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for testing the AttGAN network"""

import argparse
import json
import math
import os
from os.path import join

import numpy as np
from PIL import Image

import mindspore.common.dtype as mstype
import mindspore.dataset as de
from mindspore import context, Tensor, ops
from mindspore.train.serialization import load_param_into_net

from src.attgan import Gen
from src.cell import init_weights
from src.data import check_attribute_conflict
from src.data import get_loader, Custom
from src.helpers import Progressbar
from src.utils import resume_generator, denorm

device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)


def parse(arg=None):
    """Define the evaluation configuration."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', dest='experiment_name', required=True)
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--num_test', dest='num_test', type=int)
    parser.add_argument('--gen_ckpt_name', type=str, default='')
    parser.add_argument('--custom_img', action='store_true')
    parser.add_argument('--custom_data', type=str, default='../data/custom')
    parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
    parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
    return parser.parse_args(arg)


args_ = parse()
print(args_)

with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.gen_ckpt_name = args_.gen_ckpt_name
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr
args.shortcut_layers = args_.shortcut_layers
args.inject_layers = args_.inject_layers
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)

print(args)

# Data loader
if args.custom_img:
    output_path = join("output", args.experiment_name, "custom_testing")
    os.makedirs(output_path, exist_ok=True)
    test_dataset = Custom(args.custom_data, args.custom_attr, args.attrs)
    test_len = len(test_dataset)
else:
    output_path = join("output", args.experiment_name, "sample_testing")
    os.makedirs(output_path, exist_ok=True)
    test_dataset = get_loader(args.data_path, args.attr_path,
                              selected_attrs=args.attrs,
                              mode="test")
    test_len = len(test_dataset)
dataset_column_names = ["image", "attr"]
num_parallel_workers = 8
ds = de.GeneratorDataset(test_dataset, column_names=dataset_column_names,
                         num_parallel_workers=min(32, num_parallel_workers))
ds = ds.batch(1, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=False)
test_dataset_iter = ds.create_dict_iterator()

if args.num_test is None:
    print('Testing images:', test_len)
else:
    print('Testing images:', min(test_len, args.num_test))

# Model loader
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
          args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test')

# Initialize network
init_weights(gen, 'KaimingUniform', math.sqrt(5))
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
load_param_into_net(gen, para_gen)

print("Network initialized successfully.")

progressbar = Progressbar()
it = 0
for data in test_dataset_iter:
    img_a = data["image"]
    att_a = data["attr"]
    if args.num_test is not None and it == args.num_test:
        break

    att_a = Tensor(att_a, mstype.float32)
    att_b_list = [att_a]
    for i in range(args.n_attrs):
        clone = ops.Identity()
        tmp = clone(att_a)
        tmp[:, i] = 1 - tmp[:, i]
        tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
        att_b_list.append(tmp)

    samples = [img_a]

    for i, att_b in enumerate(att_b_list):
        att_b_ = (att_b * 2 - 1) * args.thres_int
        if i > 0:
            att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
        samples.append(gen(img_a, att_b_, mode="enc-dec"))
    cat = ops.Concat(axis=3)
    samples = cat(samples).asnumpy()
    result = denorm(samples)
    result = np.reshape(result, (128, -1, 3))
    im = Image.fromarray(np.uint8(result))
    if args.custom_img:
        out_file = test_dataset.images[it]
    else:
        out_file = "{:06d}.jpg".format(it + 182638)
    im.save(output_path + '/' + out_file)
    print('Successfully saved image to ' + output_path + '/' + out_file)
    it += 1
|
|
@@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import argparse
import json
from os.path import join
import numpy as np

from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net

from src.utils import resume_generator
from src.attgan import Gen

parser = argparse.ArgumentParser(description='Attribute Edit')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--experiment_name', dest='experiment_name', required=True)

args_ = parser.parse_args()
print(args_)

with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.device_id = args_.device_id
args.batch_size = args_.batch_size
args.gen_ckpt_name = args_.gen_ckpt_name
args.file_format = args_.file_format

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

if __name__ == '__main__':
    gen = Gen(mode="test")

    para_gen = resume_generator(args, gen, args.gen_ckpt_name)
    load_param_into_net(gen, para_gen)

    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
    input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))

    Gen_file = "attgan_mindir"
    export(gen, *(input_array, input_label), file_name=Gen_file, file_format=args.file_format)
@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from PIL import Image


def parse(arg=None):
    """Define configuration of postprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--bin_path', type=str, default='./result_Files/')
    parser.add_argument('--target_path', type=str, default='./result_Files/')
    return parser.parse_args(arg)


def load_bin_file(bin_file, shape=None, dtype="float32"):
    """Load data from bin file"""
    data = np.fromfile(bin_file, dtype=dtype)
    if shape:
        data = np.reshape(data, shape)
    return data


def save_bin_to_image(data, out_name):
    """Save bin file to image arrays"""
    image = np.transpose(data, (1, 2, 0))
    im = Image.fromarray(np.uint8(image))
    im.save(out_name)
    print("Successfully saved image to " + out_name)


def scan_dir(bin_path):
    """Scan directory"""
    out = os.listdir(bin_path)
    return out


def postprocess(bin_path):
    """Post process bin file"""
    file_list = scan_dir(bin_path)
    for file in file_list:
        data = load_bin_file(bin_path + file, shape=(3, 128, 128), dtype="float32")
        pos = file.find(".")
        file_name = file[0:pos] + "." + "jpg"
        outfile = os.path.join(args.target_path, file_name)
        save_bin_to_image(data, outfile)


if __name__ == "__main__":
    args = parse()
    postprocess(args.bin_path)
@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pre process for 310 inference"""
import os
from os.path import join

import argparse
import numpy as np

selected_attrs = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]


def parse(arg=None):
    """Define configuration of preprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
    parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    return parser.parse_args(arg)


args = parse()
args.n_attrs = len(args.attrs)


def check_attribute_conflict(att_batch, att_name, att_names):
    """Check attribute conflicts"""
    def _set(att, att_name):
        if att_name in att_names:
            att[att_names.index(att_name)] = 0.0

    att_id = att_names.index(att_name)
    for att in att_batch:
        if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
            _set(att, 'Bangs')
        elif att_name == 'Bangs' and att[att_id] != 0:
            _set(att, 'Bald')
            _set(att, 'Receding_Hairline')
        elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
            for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
            for n in ['Straight_Hair', 'Wavy_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
            for n in ['Mustache', 'No_Beard']:
                if n != att_name:
                    _set(att, n)
    return att_batch


def read_cfg_file(attr_path):
    """Read configuration from attribute file"""
    attr_list = open(attr_path, "r", encoding="utf-8").readlines()[1].split()
    atts = [attr_list.index(att) + 1 for att in selected_attrs]
    labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
    attr_number = int(open(attr_path, "r", encoding="utf-8").readlines()[0])
    labels = [labels] if attr_number == 1 else labels[0:]
    new_attr = []
    for index in range(attr_number):
        att = [np.asarray((labels[index] + 1) // 2)]
        new_attr.append(att)
    new_attr = np.array(new_attr)
    return new_attr, attr_number


def preprocess_cfg(attrs, numbers):
    """Preprocess attribute file"""
    new_attr = []
    for index in range(numbers):
        attr = attrs[index]
        att_b_list = [attr]
        for i in range(args.n_attrs):
            tmp = attr.copy()
            tmp[:, i] = 1 - tmp[:, i]
            tmp = check_attribute_conflict(tmp, selected_attrs[i], selected_attrs)
            att_b_list.append(tmp)
        for i, att_b in enumerate(att_b_list):
            att_b_ = (att_b * 2 - 1) * args.thres_int
            if i > 0:
                att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
            new_attr.append(att_b_)
    return new_attr


def write_cfg_file(attrs, numbers):
    """Write attribute file"""
    cur_dir = os.getcwd()
    print(cur_dir)
    path = join(cur_dir, 'attrs.txt')
    with open(path, "w") as f:
        f.writelines(str(numbers))
        f.writelines("\n")
        f.writelines(str(args.n_attrs))
        f.writelines("\n")
        counts = numbers * args.n_attrs
        for index in range(counts):
            attrs_list = attrs[index][0]
            new_attrs_list = ["%s" % x for x in attrs_list]
            sequence = " ".join(new_attrs_list)
            f.writelines(sequence)
            f.writelines("\n")
    print("Generated cfg file successfully.")


if __name__ == "__main__":
    if args.attrs_path is None:
        print("Path is not correct!")
    attributes, n_images = read_cfg_file(args.attrs_path)
    new_attrs = preprocess_cfg(attributes, n_images)
    write_cfg_file(new_attrs, n_images)
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [ATTR_PATH]"
    exit 1
fi

export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
echo "hccl connect timeout has been changed to 600 seconds"

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

data_path=$2
attr_path=$3

cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical cores:" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cd ./LOG$i || exit
    echo "Start training for rank $i, device $DEVICE_ID"

    env > env.log
    cd ../../
    taskset -c $cmdopt python train.py \
        --img_size 128 \
        --shortcut_layers 1 \
        --inject_layers 1 \
        --experiment_name 128_shortcut1_inject1_none \
        --data_path $data_path \
        --attr_path $attr_path \
        --run_distribute 1 > ./scripts/LOG$i/log.txt 2>&1 &
    cd scripts
done
@@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]
then
    echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]"
    exit 1
fi

experiment_name=$1
data_path=$2
attr_path=$3
gen_ckpt_name=$4

cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "The number of logical cores:" $cores

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

rm -rf EVAL_LOG
mkdir ./EVAL_LOG
cd ./EVAL_LOG || exit
echo "Start evaluation for rank 0, device 0, directory is EVAL_LOG"

env > env.log
cd ../../

python eval.py \
    --experiment_name $experiment_name \
    --test_int 1.0 \
    --custom_data $data_path \
    --custom_attr $attr_path \
    --custom_img \
    --gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
@@ -0,0 +1,128 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [[ $# -lt 4 || $# -gt 5 ]]; then
    echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
    NEED_PREPROCESS means whether preprocessing is needed or not; its value is 'y' or 'n'.
    DEVICE_ID is optional; it can be set by the environment variable device_id, otherwise the value is zero"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

gen_model=$(get_real_path $1)
attr_path=$(get_real_path $2)
data_path=$(get_real_path $3)

if [ "$4" == "y" ] || [ "$4" == "n" ];then
    need_preprocess=$4
else
    echo "NEED_PREPROCESS (whether preprocessing is needed) must be 'y' or 'n'"
    exit 1
fi

device_id=0
if [ $# == 5 ]; then
    device_id=$5
fi

echo "generator mindir name: "$gen_model
echo "attribute file path: "$attr_path
echo "dataset path: "$data_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id

export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function preprocess_data()
{
    echo "Start to preprocess attr file..."
    python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log
    echo "Attribute file generated successfully!"
}

function compile_app()
{
    echo "Start to compile source code..."
    cd ../ascend310_infer || exit
    bash build.sh &> build.log
    echo "Compiled successfully."
}

function infer()
{
    cd - || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result
    echo "Start to execute inference..."
    ../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log
}

function postprocess_data()
{
    echo "Start to postprocess image file..."
    python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
}

if [ $need_preprocess == "y" ]; then
    preprocess_data
    if [ $? -ne 0 ]; then
        echo "preprocess attrs failed"
        exit 1
    fi
fi

compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi

infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi

postprocess_data
if [ $? -ne 0 ]; then
    echo "postprocess images failed"
    exit 1
fi
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_single_train.sh [EXPERIMENT_NAME] [DATA_PATH] [ATTR_PATH]"
    exit 1
fi

experiment_name=$1
data_path=$2
attr_path=$3

cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical cores:" $cores

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

rm -rf LOG
mkdir ./LOG
cd ./LOG || exit
echo "Start training for rank 0, device 0, directory is LOG"

env > env.log
cd ../../

python train.py \
    --img_size 128 \
    --shortcut_layers 1 \
    --inject_layers 1 \
    --experiment_name $experiment_name \
    --data_path $data_path \
    --attr_path $attr_path > ./scripts/LOG/log.txt 2>&1 &
@@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AttGAN Network Topology"""

import mindspore.ops.operations as P
from mindspore import nn

from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock

# Image size 128 x 128
MAX_DIM = 64 * 16


class Gen(nn.Cell):
    """Generator"""
    def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
                 dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu",
                 n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"):
        super().__init__()
        self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
        self.inject_layers = min(inject_layers, dec_layers - 1)
        self.f_size = img_size // 2 ** dec_layers  # f_size = 4 for 128x128

        layers = []
        n_in = 3
        for i in range(enc_layers):
            n_out = min(enc_dim * 2 ** i, MAX_DIM)
            layers += [Conv2dBlock(
                n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=enc_norm_fn, acti_fn=enc_acti_fn, mode=mode
            )]
            n_in = n_out
        self.enc_layers = nn.CellList(layers)

        layers = []
        n_in = n_in + n_attrs  # 1024 + 13
        for i in range(dec_layers):
            if i < dec_layers - 1:
                n_out = min(dec_dim * 2 ** (dec_layers - i - 1), MAX_DIM)
                layers += [ConvTranspose2dBlock(
                    n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=dec_norm_fn, acti_fn=dec_acti_fn, mode=mode
                )]
                n_in = n_out
                n_in = n_in + n_in // 2 if self.shortcut_layers > i else n_in
                n_in = n_in + n_attrs if self.inject_layers > i else n_in
            else:
                layers += [ConvTranspose2dBlock(
                    n_in, 3, (4, 4), stride=2, padding=1, norm_fn='none', acti_fn='tanh', mode=mode
                )]
        self.dec_layers = nn.CellList(layers)

        self.view = P.Reshape()
        self.repeat = P.Tile()
        self.cat = P.Concat(1)

    def encoder(self, x):
        """Encoder construct"""
        z = x
        zs = []
        for layer in self.enc_layers:
            z = layer(z)
            zs.append(z)
        return zs

    def decoder(self, zs, a):
        """Decoder construct"""
        a_tile = self.view(a, (a.shape[0], -1, 1, 1))
        multiples = (1, 1, self.f_size, self.f_size)
        a_tile = self.repeat(a_tile, multiples)

        z = self.cat((zs[-1], a_tile))
        i = 0
        for layer in self.dec_layers:
            z = layer(z)
            if self.shortcut_layers > i:
                z = self.cat((z, zs[len(self.dec_layers) - 2 - i]))
            if self.inject_layers > i:
                a_tile = self.view(a, (a.shape[0], -1, 1, 1))
                multiples = (1, 1, self.f_size * 2 ** (i + 1), self.f_size * 2 ** (i + 1))
                a_tile = self.repeat(a_tile, multiples)
                z = self.cat((z, a_tile))
            i = i + 1
        return z

    def construct(self, x, a=None, mode="enc-dec"):
        result = None
        if mode == "enc-dec":
            out = self.encoder(x)
            result = self.decoder(out, a)
        if mode == "enc":
            result = self.encoder(x)
        if mode == "dec":
            result = self.decoder(x, a)
        return result


class Dis(nn.Cell):
    """Discriminator"""
    def __init__(self, dim=64, norm_fn='none', acti_fn='lrelu',
                 fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128, mode='test'):
        super().__init__()
        self.f_size = img_size // 2 ** n_layers

        layers = []
        n_in = 3
        for i in range(n_layers):
            n_out = min(dim * 2 ** i, MAX_DIM)
            layers += [Conv2dBlock(
                n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=norm_fn, acti_fn=acti_fn, mode=mode
            )]
            n_in = n_out
        self.conv = nn.SequentialCell(layers)
        self.fc_adv = nn.SequentialCell(
            [LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
             LinearBlock(fc_dim, 1, 'none', 'none', mode)])

        self.fc_cls = nn.SequentialCell(
            [LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
             LinearBlock(fc_dim, 13, 'none', 'none', mode)])

    def construct(self, x):
        """construct"""
        h = self.conv(x)
        view = P.Reshape()
        h = view(h, (h.shape[0], -1))
        return self.fc_adv(h), self.fc_cls(h)
@@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network Component"""

from mindspore import nn


def add_normalization_1d(layers, fn, n_out, mode='test'):
    if fn == "none":
        pass
    elif fn == "batchnorm":
        layers.append(nn.BatchNorm1d(n_out, use_batch_statistics=(mode == 'train')))
    elif fn == "instancenorm":
        layers.append(nn.GroupNorm(n_out, n_out, affine=True))
    else:
        raise Exception('Unsupported normalization: ' + str(fn))
    return layers


def add_normalization_2d(layers, fn, n_out, mode='test'):
    if fn == 'none':
        pass
    elif fn == 'batchnorm':
        layers.append(nn.BatchNorm2d(n_out, use_batch_statistics=(mode == 'train')))
    elif fn == "instancenorm":
        layers.append(nn.GroupNorm(n_out, n_out, affine=True))
    else:
        raise Exception('Unsupported normalization: ' + str(fn))
    return layers


def add_activation(layers, fn):
    """Add Activation"""
    if fn == "none":
        pass
    elif fn == "relu":
        layers.append(nn.ReLU())
    elif fn == "lrelu":
        layers.append(nn.LeakyReLU(alpha=0.01))
    elif fn == "sigmoid":
        layers.append(nn.Sigmoid())
    elif fn == "tanh":
        layers.append(nn.Tanh())
    else:
        raise Exception('Unsupported activation function: ' + str(fn))
    return layers


class LinearBlock(nn.Cell):
    """Linear Block"""
    def __init__(self, n_in, n_out, norm_fn="none", acti_fn="none", mode='test'):
        super().__init__()
        layers = [nn.Dense(n_in, n_out, has_bias=(norm_fn == 'none'))]
        layers = add_normalization_1d(layers, norm_fn, n_out, mode)
        layers = add_activation(layers, acti_fn)
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)


class Conv2dBlock(nn.Cell):
    """Convolution Block"""
    def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
                 norm_fn=None, acti_fn=None, mode='test'):
        super().__init__()
        layers = [nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
                            has_bias=(norm_fn == 'none'))]
        layers = add_normalization_2d(layers, norm_fn, n_out, mode)
        layers = add_activation(layers, acti_fn)
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)


class ConvTranspose2dBlock(nn.Cell):
    """Transpose Convolution Block"""
    def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
                 norm_fn=None, acti_fn=None, mode='test'):
        super().__init__()
        layers = [nn.Conv2dTranspose(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
                                     has_bias=(norm_fn == 'none'))]
        layers = add_normalization_2d(layers, norm_fn, n_out, mode)
        layers = add_activation(layers, acti_fn)
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)
@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell Definition"""

import numpy as np

import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import nn, ops
from mindspore.common import initializer as init, set_seed
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)

set_seed(1)
np.random.seed(1)


def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights.

    Parameters:
        net (Cell): Network to be initialized
        init_type (str): The name of an initialization method: normal | xavier.
        init_gain (float): Gain factor for normal and xavier.
    """
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if init_type == 'normal':
                cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
            elif init_type == 'xavier':
                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
            elif init_type == 'KaimingUniform':
                cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
            elif init_type == 'constant':
                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif isinstance(cell, (nn.GroupNorm, nn.BatchNorm2d)):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


class GenWithLossCell(nn.Cell):
    """
    Wrap the network with loss function to return generator loss
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        _, g_loss, _, _, _ = self.network(img_a, att_a, att_a_, att_b, att_b_)
        return g_loss


class DisWithLossCell(nn.Cell):
    """
    Wrap the network with loss function to return discriminator loss
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        d_loss, _, _, _, _ = self.network(img_a, att_a, att_a_, att_b, att_b_)
        return d_loss


class TrainOneStepCellGen(nn.Cell):
    """Encapsulation class of AttGAN generator network training."""

    def __init__(self, generator, optimizer, sens=1.0):
        super().__init__()
        self.optimizer = optimizer
        self.generator = generator
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)

        self.sens = sens
        self.weights = optimizer.parameters
        self.network = GenWithLossCell(generator)
        self.network.add_flags(defer_inline=True)

        self.reducer_flag = False
        self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
weights = self.weights
|
||||
_, loss, gf_loss, gc_loss, gr_loss = self.generator(img_a, att_a, att_a_, att_b, att_b_)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
|
||||
|
||||
|
||||
class TrainOneStepCellDis(nn.Cell):
|
||||
"""Encapsulation class of AttGAN discriminator network training."""
|
||||
|
||||
def __init__(self, discriminator, optimizer, sens=1.0):
|
||||
super().__init__()
|
||||
self.optimizer = optimizer
|
||||
self.discriminator = discriminator
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.weights = optimizer.parameters
|
||||
self.network = DisWithLossCell(discriminator)
|
||||
self.network.add_flags(defer_inline=True)
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
weights = self.weights
|
||||
loss, d_real_loss, d_fake_loss, dc_loss, df_gp = self.discriminator(img_a, att_a, att_a_, att_b, att_b_)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp
|
|
# ============================== src/data.py ==============================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DataLoader: CelebA"""

import os
import numpy as np
from PIL import Image

import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore.dataset.transforms import py_transforms

from src.utils import DistributedSampler


class Custom:
    """
    Custom data loader
    """

    def __init__(self, data_path, attr_path, selected_attrs):
        self.data_path = data_path

        att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
        atts = [att_list.index(att) + 1 for att in selected_attrs]
        # np.str / np.int are deprecated aliases; use the builtins
        images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
        labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)

        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = [py_vision.ToPIL()]
        transform.append(py_vision.Resize([128, 128]))
        transform.append(py_vision.ToTensor())
        transform.append(py_vision.Normalize(mean=mean, std=std))
        transform = py_transforms.Compose(transform)
        self.transform = transform
        self.images = np.array([images]) if images.size == 1 else images[0:]
        self.labels = np.array([labels]) if images.size == 1 else labels[0:]
        self.length = len(self.images)

    def __getitem__(self, index):
        image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
        # attribute labels come in as -1/1; map them to 0/1
        att = np.asarray((self.labels[index] + 1) // 2)
        image = np.squeeze(self.transform(image))
        return image, att

    def __len__(self):
        return self.length
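The CelebA attribute file stores each label as -1/1; both loaders here binarize it with `(label + 1) // 2` before returning it. A minimal numpy sketch of that mapping (the values are illustrative):

```python
import numpy as np

# CelebA-style raw attribute row: -1 means "absent", 1 means "present"
raw = np.array([-1, 1, 1, -1, 1])

# the loaders' binarization step: (label + 1) // 2 maps -1 -> 0 and 1 -> 1
binary = (raw + 1) // 2
print(binary.tolist())  # [0, 1, 1, 0, 1]
```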
class CelebA:
    """
    CelebA dataset
    Input:
        data_path: image path
        attr_path: attribute-list path
        image_size: image size
        mode: train or test
        selected_attrs: selected attributes
        transform: image preprocessing
        split_point: index separating the training and test partitions
    """

    def __init__(self, data_path, attr_path, image_size, mode, selected_attrs, transform, split_point=182000):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.img_size = image_size

        att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
        atts = [att_list.index(att) + 1 for att in selected_attrs]

        images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
        labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)

        if mode == "train":
            self.images = images[:split_point]
            self.labels = labels[:split_point]
        if mode == "test":
            self.images = images[split_point:]
            self.labels = labels[split_point:]

        self.length = len(self.images)

    def __getitem__(self, index):
        image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
        att = np.asarray((self.labels[index] + 1) // 2)
        image = np.squeeze(self.transform(image))
        return image, att

    def __len__(self):
        return self.length
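`CelebA` partitions the attribute list by a single index: everything before `split_point` goes to the training set, the rest to the test set. A toy sketch of that slicing (the sizes are illustrative, not the real 202,599-image dataset):

```python
import numpy as np

images = np.array([f"{i:06d}.jpg" for i in range(10)])
split_point = 7  # the real default is 182000

# mode == "train" takes the head of the list, mode == "test" the tail
train_images = images[:split_point]
test_images = images[split_point:]

print(len(train_images), len(test_images))  # 7 3
```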
def get_loader(data_root, attr_path, selected_attrs, crop_size=170, image_size=128, mode="train", split_point=182000):
    """Build and return the CelebA dataset"""

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    transform = [py_vision.ToPIL()]
    transform.append(py_vision.CenterCrop((crop_size, crop_size)))
    transform.append(py_vision.Resize([image_size, image_size]))
    transform.append(py_vision.ToTensor())
    transform.append(py_vision.Normalize(mean=mean, std=std))
    transform = py_transforms.Compose(transform)

    dataset = CelebA(data_root, attr_path, image_size, mode, selected_attrs, transform, split_point=split_point)

    return dataset


def data_loader(img_path, attr_path, selected_attrs, mode="train", batch_size=1, device_num=1, rank=0, shuffle=True,
                split_point=182000):
    """CelebA data loader"""
    num_parallel_workers = 8

    dataset_loader = get_loader(img_path, attr_path, selected_attrs, mode=mode, split_point=split_point)
    length_dataset = len(dataset_loader)

    distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
    dataset_column_names = ["image", "attr"]

    if device_num != 8:
        ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
                                 num_parallel_workers=min(32, num_parallel_workers),
                                 sampler=distributed_sampler)
        ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
        ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)

    return ds, length_dataset


def check_attribute_conflict(att_batch, att_name, att_names):
    """Zero out attributes that conflict with the attribute being edited"""
    def _set(att, att_name):
        if att_name in att_names:
            att[att_names.index(att_name)] = 0.0

    att_id = att_names.index(att_name)
    for att in att_batch:
        if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
            _set(att, 'Bangs')
        elif att_name == 'Bangs' and att[att_id] != 0:
            _set(att, 'Bald')
            _set(att, 'Receding_Hairline')
        elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
            for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
            for n in ['Straight_Hair', 'Wavy_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
            for n in ['Mustache', 'No_Beard']:
                if n != att_name:
                    _set(att, n)
    return att_batch
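`check_attribute_conflict` enforces mutual exclusivity between related attributes: for example, turning on `Blond_Hair` zeroes the other hair colors. A self-contained sketch of just the hair-color branch (the function above handles several more attribute groups):

```python
def clear_conflicts(att_batch, att_name, att_names):
    """Minimal version of the hair-color rule: enabling one color clears the others."""
    colors = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
    att_id = att_names.index(att_name)
    for att in att_batch:
        if att_name in colors and att[att_id] != 0:
            for n in colors:
                if n != att_name and n in att_names:
                    att[att_names.index(n)] = 0.0
    return att_batch

names = ['Black_Hair', 'Blond_Hair', 'Brown_Hair']
batch = [[1.0, 1.0, 1.0]]
print(clear_conflicts(batch, 'Blond_Hair', names))  # [[0.0, 1.0, 0.0]]
```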
if __name__ == "__main__":
    # these imports are only needed for this self-test
    import argparse
    from mindspore import context

    attrs_default = [
        'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
        'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to test')
    parser.add_argument('--data_path', dest='data_path', type=str, required=True)
    parser.add_argument('--attr_path', dest='attr_path', type=str, required=True)
    args = parser.parse_args()

    ####### Test CelebA #######
    context.set_context(device_target="Ascend")

    dataset_ce, length_ce = data_loader(args.data_path, args.attr_path, attrs_default, mode="train")
    for i, data in enumerate(dataset_ce.create_dict_iterator()):
        print('Number:', i, 'Value:', data["attr"], 'Type:', type(data["attr"]))
# ============================== src/helpers.py ==============================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for training"""

import datetime
import platform

from tqdm import tqdm


def name_experiment(prefix="", suffix=""):
    """Build an experiment name from the current time and host, with optional prefix/suffix."""
    experiment_name = datetime.datetime.now().strftime('%b%d_%H-%M-%S_') + platform.node()
    if prefix is not None and prefix != '':
        experiment_name = prefix + '_' + experiment_name
    if suffix is not None and suffix != '':
        experiment_name = experiment_name + '_' + suffix
    return experiment_name
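`name_experiment` stamps each run as `<prefix>_<time>_<host>_<suffix>`. A standalone copy of the naming logic (timestamps and hostnames vary between runs, so the demo only checks the affix placement):

```python
import datetime
import platform

def name_experiment(prefix="", suffix=""):
    # timestamp + hostname core, e.g. "Jan01_12-00-00_myhost"
    name = datetime.datetime.now().strftime('%b%d_%H-%M-%S_') + platform.node()
    if prefix:  # equivalent to the "not None and not empty" checks above
        name = prefix + '_' + name
    if suffix:
        name = name + '_' + suffix
    return name

name = name_experiment(prefix="attgan", suffix="128x128")
print(name.startswith("attgan_"), name.endswith("_128x128"))  # True True
```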
class Progressbar:
    """Progress bar wrapper around tqdm"""

    def __init__(self):
        self.p = None

    def __call__(self, iterable, length):
        self.p = tqdm(iterable, total=length)
        return self.p

    def say(self, **kwargs):
        if self.p is not None:
            self.p.set_postfix(**kwargs)
# ============================== src/loss.py ==============================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss Computation of Generator and Discriminator"""

import numpy as np

import mindspore
import mindspore.ops.operations as P
from mindspore import dtype as mstype
from mindspore import nn, Tensor, ops
from mindspore.ops import constexpr


class ClassificationLoss(nn.Cell):
    """Define classification loss for AttGAN"""

    def __init__(self):
        super().__init__()
        self.bce_loss = P.BinaryCrossEntropy(reduction='sum')

    def construct(self, pred, label):
        weight = ops.Ones()(pred.shape, mindspore.float32)
        pred_ = P.Sigmoid()(pred)
        x = self.bce_loss(pred_, label, weight) / pred.shape[0]
        return x


@constexpr
def generate_tensor(batch_size):
    np_array = np.random.randn(batch_size, 1, 1, 1)
    return Tensor(np_array, mindspore.float32)


class GradientWithInput(nn.Cell):
    """Get discriminator gradient with respect to its input"""

    def __init__(self, discriminator):
        super().__init__()
        self.reduce_sum = ops.ReduceSum()
        self.discriminator = discriminator
        self.discriminator.set_train(mode=True)

    def construct(self, interpolates):
        decision_interpolate, _ = self.discriminator(interpolates)
        decision_interpolate = self.reduce_sum(decision_interpolate, 0)
        return decision_interpolate


class WGANGPGradientPenalty(nn.Cell):
    """Define WGAN-GP gradient penalty for AttGAN"""

    def __init__(self, discriminator):
        super().__init__()
        self.gradient_op = ops.GradOperation()

        self.reduce_sum = ops.ReduceSum()
        self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)

        self.sqrt = ops.Sqrt()
        self.discriminator = discriminator
        self.GradientWithInput = GradientWithInput(discriminator)

    def construct(self, x_real, x_fake):
        """get gradient penalty"""
        batch_size = x_real.shape[0]
        alpha = generate_tensor(batch_size)
        alpha = alpha.expand_as(x_real)
        x_fake = ops.functional.stop_gradient(x_fake)
        x_hat = x_real + alpha * (x_fake - x_real)

        gradient = self.gradient_op(self.GradientWithInput)(x_hat)
        gradient_1 = ops.reshape(gradient, (batch_size, -1))
        gradient_1 = self.sqrt(self.reduce_sum(gradient_1 * gradient_1, 1))
        gradient_penalty = self.reduce_sum((gradient_1 - 1.0) ** 2) / x_real.shape[0]
        return gradient_penalty
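`WGANGPGradientPenalty` implements the WGAN-GP term: sample interpolates `x_hat = x_real + alpha * (x_fake - x_real)`, take the critic's gradient at `x_hat`, and penalize its deviation from unit norm, `mean((||grad|| - 1)^2)`. A numpy sketch using a hypothetical linear critic `f(x) = w @ x`, whose input gradient is `w` everywhere (including at `x_hat`), so the penalty can be checked analytically:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim = 4, 8
x_real = rng.standard_normal((batch, dim))
x_fake = rng.standard_normal((batch, dim))
w = rng.standard_normal(dim)  # linear critic f(x) = w @ x, so df/dx = w at any point

# interpolate between real and fake samples, as in WGANGPGradientPenalty.construct
alpha = rng.random((batch, 1))
x_hat = x_real + alpha * (x_fake - x_real)

# per-sample critic gradient at x_hat: constant w for a linear critic
grads = np.tile(w, (batch, 1))
grad_norms = np.sqrt((grads ** 2).sum(axis=1))

# the penalty drives per-sample gradient norms toward 1
penalty = ((grad_norms - 1.0) ** 2).mean()
print(float(penalty))
```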
class GenLoss(nn.Cell):
    """Define total generator loss"""

    def __init__(self, args, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

        self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
        self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
        self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
        self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)

        self.cyc_loss = P.ReduceMean()
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.rec_loss = nn.L1Loss("mean")

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        """Get generator loss"""
        # generate
        zs_a = self.generator(img_a, mode="enc")
        img_fake = self.generator(zs_a, att_b_, mode="dec")
        img_recon = self.generator(zs_a, att_a_, mode="dec")

        # discriminate
        d_fake, dc_fake = self.discriminator(img_fake)

        # generator loss
        gf_loss = -self.cyc_loss(d_fake)
        gc_loss = self.cls_loss(dc_fake, att_b)
        gr_loss = self.rec_loss(img_a, img_recon)

        g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss

        return (img_fake, g_loss, gf_loss, gc_loss, gr_loss)


class DisLoss(nn.Cell):
    """Define total discriminator loss"""

    def __init__(self, args, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

        self.cyc_loss = P.ReduceMean()
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.WGANLoss = WGANGPGradientPenalty(discriminator)

        self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
        self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
        self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
        self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        """Get discriminator loss"""
        # generate
        img_fake = self.generator(img_a, att_b_, mode="enc-dec")

        # discriminate
        d_real, dc_real = self.discriminator(img_a)
        d_fake, _ = self.discriminator(img_fake)

        # discriminator losses
        d_real_loss = -self.cyc_loss(d_real)
        d_fake_loss = self.cyc_loss(d_fake)
        df_loss = d_real_loss + d_fake_loss

        df_gp = self.WGANLoss(img_a, img_fake)

        dc_loss = self.cls_loss(dc_real, att_a)

        d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss

        return (d_loss, d_real_loss, d_fake_loss, dc_loss, df_gp)
# ============================== src/utils.py ==============================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions"""
import math
import os

import numpy as np
from mindspore import load_checkpoint


class DistributedSampler:
    """Distributed sampler."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
        if num_replicas is None:
            print("*********** Setting world_size to 1 since it is not passed in ***********")
            num_replicas = 1
        if rank is None:
            print("*********** Setting rank to 0 since it is not passed in ***********")
            rank = 0

        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.epoch = 0
        self.rank = rank
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # pad so every replica draws the same number of samples
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample this replica's strided shard
        indices = indices[self.rank: self.total_size: self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
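`DistributedSampler` pads the index list to a multiple of the replica count, then each rank takes a strided shard. A standalone sketch of that arithmetic (pure Python, no shuffling), assuming 10 samples across 4 replicas:

```python
import math

def shard_indices(dataset_size, num_replicas, rank):
    # pad to a common per-replica length, as in DistributedSampler.__iter__
    num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_size))
    indices += indices[:total_size - len(indices)]
    # each rank takes a strided shard starting at its own offset
    return indices[rank:total_size:num_replicas]

shards = [shard_indices(10, 4, r) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

Note that the padding wraps around the start of the index list, so a few samples are seen twice per epoch when the dataset size is not divisible by the replica count.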
def resume_generator(args, generator, gen_ckpt_name):
    """Restore the trained generator"""
    print("Loading the trained generator from step {}...".format(args.save_interval))
    generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name)
    param_generator = load_checkpoint(generator_path, generator)

    return param_generator


def resume_discriminator(args, discriminator, dis_ckpt_name):
    """Restore the trained discriminator"""
    print("Loading the trained discriminator from step {}...".format(args.save_interval))
    discriminator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dis_ckpt_name)
    param_discriminator = load_checkpoint(discriminator_path, discriminator)

    return param_discriminator


def denorm(x):
    """Map NCHW images from [-1, 1] back to NHWC values in [0, 255]"""
    image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    return image_numpy
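`denorm` undoes the `Normalize(mean=0.5, std=0.5)` preprocessing used by the data loaders: it transposes NCHW back to NHWC and maps [-1, 1] to [0, 255]. A quick numpy check of the endpoints:

```python
import numpy as np

def denorm(x):
    # NCHW in [-1, 1] -> NHWC in [0, 255]
    out = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    return np.clip(out, 0, 255)

x = np.full((1, 3, 2, 2), -1.0)   # all-black normalized image
x[0, :, 0, 0] = 1.0               # one white pixel
img = denorm(x)
print(img.shape, img.min(), img.max())  # (1, 2, 2, 3) 0.0 255.0
```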
# ============================== train.py ==============================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for training AttGAN network"""

import argparse
import datetime
import json
import math
import os
from os.path import join

import numpy as np

import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from mindspore.train.serialization import load_param_into_net

from src.attgan import Gen, Dis
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
from src.data import data_loader
from src.helpers import Progressbar
from src.loss import GenLoss, DisLoss
from src.utils import resume_generator, resume_discriminator

attrs_default = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]


def parse(arg=None):
    """Define configuration of the model"""
    parser = argparse.ArgumentParser()

    parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to learn')
    parser.add_argument('--data', dest='data', type=str, choices=['CelebA'], default='CelebA')
    parser.add_argument('--data_path', dest='data_path', type=str, default='./data/img_align_celeba')
    parser.add_argument('--attr_path', dest='attr_path', type=str, default='./data/list_attr_celeba.txt')

    parser.add_argument('--img_size', dest='img_size', type=int, default=128)
    parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
    parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)

    parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64)
    parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64)
    parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64)
    parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024)
    parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5)
    parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5)
    parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5)
    parser.add_argument('--enc_norm', dest='enc_norm', type=str, default='batchnorm')
    parser.add_argument('--dec_norm', dest='dec_norm', type=str, default='batchnorm')
    parser.add_argument('--dis_norm', dest='dis_norm', type=str, default='instancenorm')
    parser.add_argument('--dis_fc_norm', dest='dis_fc_norm', type=str, default='none')
    parser.add_argument('--enc_acti', dest='enc_acti', type=str, default='lrelu')
    parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu')
    parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu')
    parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu')
    parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
    parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
    parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
    parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)

    parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
    parser.add_argument('--num_workers', dest='num_workers', type=int, default=16)
    parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--beta1', dest='beta1', type=float, default=0.5)
    parser.add_argument('--beta2', dest='beta2', type=float, default=0.999)
    parser.add_argument('--n_d', dest='n_d', type=int, default=5, help='number of d updates per g update')
    parser.add_argument('--split_point', dest='split_point', type=int, default=182000, help='dataset split point')

    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--save_interval', dest='save_interval', type=int, default=500)
    parser.add_argument('--experiment_name', dest='experiment_name',
                        default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
    parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
    parser.add_argument('--resume_model', action='store_true')
    parser.add_argument('--gen_ckpt_name', type=str, default='')
    parser.add_argument('--dis_ckpt_name', type=str, default='')

    return parser.parse_args(arg)
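`parse()` accepts an optional argv list, which makes the configuration easy to exercise without a command line. A trimmed sketch covering just a few of the options above (a hypothetical subset, not the full parser):

```python
import argparse

def parse(arg=None):
    # same pattern as the training script: parse_args(None) reads sys.argv,
    # parse_args(list) parses the given list
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_size', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lambda_gp', type=float, default=10.0)
    return parser.parse_args(arg)

args = parse([])                      # all defaults
print(args.img_size, args.lambda_gp)  # 128 10.0
args = parse(['--batch_size', '64'])  # override one option
print(args.batch_size)                # 64
```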
args = parse()
|
||||
print(args)
|
||||
|
||||
args.lr_base = args.lr
|
||||
args.n_attrs = len(args.attrs)
|
||||
|
||||
# initialize environment
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
if args.run_distribute:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
print(device_num)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
init()
|
||||
rank = get_rank()
|
||||
else:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
rank = 0
|
||||
|
||||
print("Initialize successful!")
|
||||
|
||||
os.makedirs(join('output', args.experiment_name), exist_ok=True)
|
||||
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
|
||||
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
|
||||
f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Define dataloader
|
||||
train_dataset, train_length = data_loader(img_path=args.data_path,
|
||||
attr_path=args.attr_path,
|
||||
selected_attrs=args.attrs,
|
||||
mode="train",
|
||||
batch_size=args.batch_size,
|
||||
device_num=device_num,
|
||||
shuffle=True,
|
||||
split_point=args.split_point)
|
||||
train_loader = train_dataset.create_dict_iterator()
|
||||
|
||||
print('Training images:', train_length)
|
||||
|
||||
# Define network
|
||||
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
|
||||
args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train')
|
||||
dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
|
||||
args.dis_layers, args.img_size, mode='train')
|
||||
|
||||
# Initialize network
|
||||
init_weights(gen, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(dis, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
    # Resume from checkpoint
    if args.resume_model:
        para_gen = resume_generator(args, gen, args.gen_ckpt_name)
        para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
        load_param_into_net(gen, para_gen)
        load_param_into_net(dis, para_dis)
    # Define network with loss
    G_loss_cell = GenLoss(args, gen, dis)
    D_loss_cell = DisLoss(args, gen, dis)

    # Define Optimizer
    optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
    optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)

    # Define One Step Train
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizer_G)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizer_D)

    # Train
    G_trainOneStep.set_train(True)
    D_trainOneStep.set_train(True)
    print("Start Training")

    train_iter = train_length // args.batch_size
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_interval)

    if rank == 0:
        local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
        ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator')
        ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')

        cb_params_gen = _InternalCallbackParam()
        cb_params_gen.train_network = gen
        cb_params_gen.cur_epoch_num = 0
        gen_run_context = RunContext(cb_params_gen)
        ckpt_cb_gen.begin(gen_run_context)

        cb_params_dis = _InternalCallbackParam()
        cb_params_dis.train_network = dis
        cb_params_dis.cur_epoch_num = 0
        dis_run_context = RunContext(cb_params_dis)
        ckpt_cb_dis.begin(dis_run_context)

    # Initialize Progressbar
    progressbar = Progressbar()
    it = 0
    for epoch in range(args.epochs):
        for data in progressbar(train_loader, train_iter):
            img_a = data["image"]
            att_a = data["attr"]
            att_a = att_a.asnumpy()
            # Shuffle attribute rows across the batch to obtain target attributes
            att_b = np.random.permutation(att_a)

            # Rescale binary {0, 1} labels into [-thres_int, thres_int]
            att_a_ = (att_a * 2 - 1) * args.thres_int
            att_b_ = (att_b * 2 - 1) * args.thres_int

            att_a = Tensor(att_a, mstype.float32)
            att_a_ = Tensor(att_a_, mstype.float32)
            att_b = Tensor(att_b, mstype.float32)
            att_b_ = Tensor(att_b_, mstype.float32)

            # Train the discriminator for n_d iterations, then the generator once
            if (it + 1) % (args.n_d + 1) != 0:
                d_out, d_real_loss, d_fake_loss, dc_loss, df_gp = D_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
            else:
                g_out, gf_loss, gc_loss, gr_loss = G_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
                progressbar.say(epoch=epoch, iter=it + 1, d_loss=d_out, g_loss=g_out, gf_loss=gf_loss, gc_loss=gc_loss,
                                gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)

            # Periodically save checkpoints on the rank-0 device
            if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
                cb_params_gen.cur_epoch_num = epoch + 1
                cb_params_dis.cur_epoch_num = epoch + 1
                cb_params_gen.cur_step_num = it + 1
                cb_params_dis.cur_step_num = it + 1
                cb_params_gen.batch_num = it + 2
                cb_params_dis.batch_num = it + 2
                ckpt_cb_gen.step_end(gen_run_context)
                ckpt_cb_dis.step_end(dis_run_context)
            it += 1
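# The target-attribute construction in the loop above can be tried in isolation:
# rows of the real attribute batch are shuffled to produce target attributes, and
# binary {0, 1} labels are mapped into the continuous range [-thres_int, thres_int].
# The helper name and the thres_int value of 0.5 below are illustrative assumptions,
# not part of the training script.

```python
import numpy as np

def make_attr_pairs(att_a, thres_int=0.5):
    """Shuffle source attributes across the batch and rescale labels.

    att_a: (batch, n_attrs) array of binary {0, 1} attribute labels.
    Returns the shuffled target attributes att_b, plus the rescaled
    conditioning vectors att_a_ and att_b_ in [-thres_int, thres_int].
    """
    att_b = np.random.permutation(att_a)   # shuffles rows -> target attributes
    att_a_ = (att_a * 2 - 1) * thres_int   # 0 -> -thres_int, 1 -> +thres_int
    att_b_ = (att_b * 2 - 1) * thres_int
    return att_b, att_a_, att_b_

att_a = np.array([[0, 1], [1, 0], [1, 1]], dtype=np.float32)
att_b, att_a_, att_b_ = make_attr_pairs(att_a, thres_int=0.5)
print(att_a_)  # each 0 becomes -0.5, each 1 becomes +0.5
```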
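# The alternation test `(it + 1) % (args.n_d + 1) != 0` in the loop above yields a
# repeating schedule of n_d discriminator steps followed by one generator step.
# A minimal pure-Python sketch of that schedule (the function name and the n_d
# values used below are illustrative assumptions):

```python
def step_role(it, n_d):
    """Return which network trains at 0-based iteration `it`,
    mirroring the `(it + 1) % (n_d + 1) != 0` alternation test."""
    return "D" if (it + 1) % (n_d + 1) != 0 else "G"

# With n_d = 2: two discriminator steps, then one generator step, repeating.
schedule = [step_role(it, n_d=2) for it in range(6)]
print(schedule)  # ['D', 'D', 'G', 'D', 'D', 'G']
```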