!22589 AttGAN commit r1.3

Merge pull request !22589 from MR.D/dev1.3
i-robot 2021-08-30 12:29:53 +00:00 committed by Gitee
commit 65dabb58ef
23 changed files with 2467 additions and 0 deletions


@ -0,0 +1,245 @@
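# Contents

<!-- TOC -->

- [Contents](#contents)
- [AttGAN Description](#attgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Inference on Ascend 310](#inference-on-ascend-310)
        - [Results](#results)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [AttGAN on CelebA](#attgan-on-celeba)
        - [Inference Performance](#inference-performance)
            - [AttGAN on CelebA](#attgan-on-celeba-1)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->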
# AttGAN Description

AttGAN refers to "AttGAN: Facial Attribute Editing by Only Changing What You Want". The network can edit a chosen facial attribute without affecting the other attributes of the face.

[Paper](https://arxiv.org/abs/1711.10678): [Zhenliang He](https://github.com/LynnHo/AttGAN-Tensorflow), [Wangmeng Zuo](https://github.com/LynnHo/AttGAN-Tensorflow), [Meina Kan](https://github.com/LynnHo/AttGAN-Tensorflow), [Shiguang Shan](https://github.com/LynnHo/AttGAN-Tensorflow), [Xilin Chen](https://github.com/LynnHo/AttGAN-Tensorflow). AttGAN: Facial Attribute Editing by Only Changing What You Want. arXiv:1711.10678, 2017.

# Model Architecture

The network consists of a generator and a discriminator, and the generator is itself composed of an encoder and a decoder. The model removes the strict attribute-independent constraint and relies only on attribute classification to guarantee that attributes are changed correctly. It combines an attribute classification constraint, adversarial learning, and reconstruction learning, which gives good facial attribute editing results.
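
The three training signals named above are typically combined into two weighted objectives, one for the encoder/decoder and one for the discriminator/classifier. The block below is a sketch of that combination, following the formulation in the AttGAN paper as an assumption. The symbols and the weights $\lambda_1, \lambda_2, \lambda_3$ are notation introduced here for illustration, and this README does not state the values used by the scripts.

```latex
\begin{align}
\min_{G_{enc},\,G_{dec}} \; \mathcal{L}_{enc,dec}
  &= \lambda_1 \mathcal{L}_{rec} + \lambda_2 \mathcal{L}_{cls_g} + \mathcal{L}_{adv_g} \\
\min_{D,\,C} \; \mathcal{L}_{dis,cls}
  &= \lambda_3 \mathcal{L}_{cls_c} + \mathcal{L}_{adv_d}
\end{align}
```

Here $\mathcal{L}_{rec}$ is the reconstruction loss, $\mathcal{L}_{cls_g}$ and $\mathcal{L}_{cls_c}$ are the attribute classification losses on generated and real images, and $\mathcal{L}_{adv_g}$, $\mathcal{L}_{adv_d}$ are the adversarial losses.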
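# Dataset

Dataset used: [CelebA](<http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>)

CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each annotated with 40 attributes. CelebA offers large diversity, large quantity, and rich annotations, including:

- 10,177 identities,
- 202,599 face images,
- 5 landmark locations and 40 binary attribute annotations per image.

The dataset can be used as training and test data for the following computer vision tasks: face attribute recognition, face detection, and face editing and synthesis.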
# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
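# Quick Start

After installing MindSpore from the official website, you can run training and evaluation as follows:

- Running on Ascend processors

  ```bash
  # Run the training example
  export DEVICE_ID=0
  export RANK_SIZE=1
  python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
  # or
  bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt

  # Run the distributed training example
  bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt

  # Run the evaluation example
  export DEVICE_ID=0
  export RANK_SIZE=1
  python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
  # or
  bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name
  ```

  Distributed training requires an hccl configuration file in JSON format, created in advance. Its absolute path is the first argument of the distributed training script.

  Follow the instructions at the link below:

  <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.

  The evaluation script requires a directory of custom images (jpg) and an attribute editing file, both created in advance. The attribute editing file is described in [Scripts and Sample Code](#scripts-and-sample-code). The directory and the attribute editing file correspond to the `custom_data` and `custom_attr` parameters. By default the training script places checkpoint files in `/output/{experiment_name}/checkpoint`. When running the evaluation script, pass the name of the generator checkpoint file as a parameter.

  [Note] All of the paths above should be absolute paths.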
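
Because every path above is expected to be absolute, a small pre-flight check can catch mistakes before a job is launched. The sketch below uses only the standard library. `require_absolute` is a hypothetical helper introduced for illustration and is not part of this repository.

```python
import os


def require_absolute(**paths):
    """Return the paths unchanged, raising ValueError if any of them is relative.

    Hypothetical helper for illustration; not part of the AttGAN scripts.
    """
    relative = {name: p for name, p in paths.items() if not os.path.isabs(p)}
    if relative:
        raise ValueError(f"paths must be absolute: {relative}")
    return paths


checked = require_absolute(
    data_path="/path/data/img_align_celeba",
    attr_path="/path/data/list_attr_celeba.txt",
)
print(checked["data_path"])
```

Passing the arguments by keyword means the error message names the offending parameter, which is usually easier to act on than a bare path.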
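# Script Description

## Scripts and Sample Code

```text
.
└─ cv
   └─ AttGAN
      ├─ ascend310_infer              # Ascend 310 inference directory
      ├─ scripts
      │  ├─ run_distribute_train.sh   # shell script for distributed training
      │  ├─ run_single_train.sh       # shell script for single-device training
      │  ├─ run_eval.sh               # evaluation script
      │  └─ run_infer_310.sh          # inference script
      ├─ src
      │  ├─ __init__.py               # package initialization
      │  ├─ block.py                  # basic cells
      │  ├─ attgan.py                 # generator and discriminator networks
      │  ├─ utils.py                  # helper functions
      │  ├─ cell.py                   # loss network wrappers
      │  ├─ data.py                   # data loading
      │  ├─ helpers.py                # progress bar display
      │  └─ loss.py                   # loss computation
      ├─ eval.py                      # test script
      ├─ train.py                     # training script
      ├─ export.py                    # MINDIR model export script
      ├─ preprocess.py                # Ascend 310 inference preprocessing script
      ├─ postprocess.py               # Ascend 310 inference postprocessing script
      └─ README_CN.md                 # description of AttGAN
```

The scripts can edit 13 attributes: Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, and Young.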
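
Some of these 13 attributes exclude each other, so turning one on should turn its rivals off. The following is a simplified, list-based sketch of the conflict rules that `preprocess.py` in this commit applies to its arrays. `EXCLUSIVE_GROUPS` and `flip_attribute` are names introduced here for illustration.

```python
ATTRS = [
    "Bald", "Bangs", "Black_Hair", "Blond_Hair", "Brown_Hair", "Bushy_Eyebrows",
    "Eyeglasses", "Male", "Mouth_Slightly_Open", "Mustache", "No_Beard",
    "Pale_Skin", "Young",
]

# Groups in which at most one attribute should be active at a time
# (restricted to the 13 attributes above).
EXCLUSIVE_GROUPS = [
    ["Bald", "Bangs"],
    ["Black_Hair", "Blond_Hair", "Brown_Hair"],
    ["Mustache", "No_Beard"],
]


def flip_attribute(labels, name):
    """Flip one 0/1 attribute and clear the attributes that conflict with it."""
    out = list(labels)  # work on a copy so the source labels stay untouched
    idx = ATTRS.index(name)
    out[idx] = 1 - out[idx]
    if out[idx] != 0:
        for group in EXCLUSIVE_GROUPS:
            if name in group:
                for other in group:
                    if other != name:
                        out[ATTRS.index(other)] = 0
    return out


# A face labelled Black_Hair, Male, No_Beard, Young.
source = [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
edited = flip_attribute(source, "Blond_Hair")
print(dict(zip(ATTRS, edited)))
```

Conflicts are only resolved when an attribute is switched on. Switching an attribute off leaves the rest of the vector as it was.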
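## Training Process

### Training

- Running on Ascend processors

  ```bash
  export DEVICE_ID=0
  export RANK_SIZE=1
  python train.py --img_size 128 --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
  ```

  When training finishes, an `output` directory is created under the current directory, with a subdirectory named after the `experiment_name` you set. The training parameters are saved in `setting.txt` in that subdirectory, and checkpoint files are saved under `output/experiment_name/rank0`. If you only need `setting.txt`, running `train.py` once is enough: the script writes `setting.txt` from the parameters you set. This file is used by the inference scripts.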
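
The evaluation and export scripts in this commit read `setting.txt` with `json.load` and an `object_hook` that turns the stored values into an `argparse.Namespace`. The sketch below reproduces that loading pattern on a temporary file. The keys written here are a hypothetical subset, and writing the file with `json.dump` is an assumption about how `train.py` produces it.

```python
import argparse
import json
import os
import tempfile

# Hypothetical subset of the values that might be recorded at training time.
settings = {
    "experiment_name": "128_shortcut1_inject1_none",
    "img_size": 128,
    "attrs": ["Bald", "Bangs"],
}

with tempfile.TemporaryDirectory() as root:
    exp_dir = os.path.join(root, "output", settings["experiment_name"])
    os.makedirs(exp_dir)
    path = os.path.join(exp_dir, "setting.txt")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(settings, f, indent=4)

    # Same loading pattern as eval.py and export.py: each JSON object becomes a Namespace.
    with open(path, "r", encoding="utf-8") as f:
        args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))

print(args.img_size, args.attrs)
```

Loading into a `Namespace` lets the later scripts use `args.img_size` in the same way as values parsed from the command line.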
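### Distributed Training

- Running on Ascend processors

  ```bash
  bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
  ```

  The shell script above runs distributed training in the background. It creates a `LOG{RANK_ID}` directory for each process under the script directory, and each process writes its output to `log.txt` in its own `LOG{RANK_ID}` directory. Checkpoint files are saved under `output/experiment_name/rank{RANK_ID}`.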
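
When checking on a distributed run it helps to know where each rank writes. The sketch below builds the per-rank log and checkpoint locations from the layout described above. `rank_paths` is a hypothetical helper, and both the `scripts` directory name and the device count of 8 are assumptions.

```python
import os


def rank_paths(script_dir, experiment_name, rank_id):
    """Return (log file, checkpoint directory) for one rank, following the layout above."""
    log_file = os.path.join(script_dir, f"LOG{rank_id}", "log.txt")
    ckpt_dir = os.path.join("output", experiment_name, f"rank{rank_id}")
    return log_file, ckpt_dir


for rank in range(8):  # 8 devices is an assumption; use your own RANK_SIZE
    log_file, ckpt_dir = rank_paths("scripts", "128_shortcut1_inject1_none", rank)
    print(rank, log_file, ckpt_dir)
```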
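## Evaluation Process

### Evaluation

- Evaluating a custom dataset on Ascend

  The network edits facial attributes. Put the images you want to edit in a custom image directory, and adjust the attribute editing file to describe the attributes you want to change (for the exact fields, refer to the CelebA dataset and its attribute file). Then pass the custom image directory and the attribute editing file to the test script as `custom_data` and `custom_attr`.

  For evaluation, choose a checkpoint file that has already been generated and pass it to the test script as `gen_ckpt_name` (it stores the encoder and decoder parameters).

  ```bash
  export DEVICE_ID=0
  export RANK_SIZE=1
  python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
  ```

  When the test script finishes, the edited images are available under `output/{experiment_name}/custom_testing` in the current directory.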
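
The `--test_int` option controls how strongly the edited attribute is applied. The sketch below mirrors the scaling done in `eval.py`, using plain lists instead of tensors: labels in {0, 1} are mapped to ±`thres_int`, and only the attribute being edited is rescaled to ±`test_int`. The function name `to_generator_input` is introduced here for illustration, and the default `thres_int=0.5` is taken from `preprocess.py`.

```python
def to_generator_input(labels, edited_index=None, thres_int=0.5, test_int=1.0):
    """Map 0/1 labels to the signed, scaled attribute vector fed to the generator."""
    scaled = [(v * 2 - 1) * thres_int for v in labels]
    if edited_index is not None:
        # Only the edited attribute is rescaled from thres_int to test_int.
        scaled[edited_index] = scaled[edited_index] * test_int / thres_int
    return scaled


labels = [0, 1, 1]  # toy three-attribute example
print(to_generator_input(labels))                  # [-0.5, 0.5, 0.5]
print(to_generator_input(labels, edited_index=1))  # [-0.5, 1.0, 0.5]
```

With these defaults, the edited attribute is presented to the generator at twice the magnitude of the others.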
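## Inference Process

### Export MindIR

```shell
python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT]
```

`file_format` must be one of ["AIR", "MINDIR"].
`experiment_name` is the name of the results folder under the `output` directory. The export script uses it to locate the saved parameters.

The script generates the corresponding MINDIR file in the current directory.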
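
The restriction on `file_format` is enforced through `argparse` choices in `export.py`. The sketch below rebuilds only the three options shown above to illustrate that behaviour. It is not the real script and exports nothing.

```python
import argparse


def build_export_parser():
    """Mirror the export options described above (a sketch, not the real export.py)."""
    parser = argparse.ArgumentParser(description="Attribute Edit")
    parser.add_argument("--experiment_name", required=True)
    parser.add_argument("--gen_ckpt_name", type=str, default="")
    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="AIR")
    return parser


args = build_export_parser().parse_args(
    ["--experiment_name", "128_shortcut1_inject1_none", "--file_format", "MINDIR"]
)
print(args.file_format)
```

In `export.py` the default format is `AIR`, so `--file_format MINDIR` has to be passed explicitly when the model is intended for Ascend 310 inference.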
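### Inference on Ascend 310

The MINDIR model must be exported with the export script before running inference. The following command shows how to edit image attributes on Ascend 310:

```bash
bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

- `GEN_MINDIR_PATH`: path to the MINDIR file.
- `ATTR_FILE_PATH`: path to the attribute editing file. It should be an absolute path.
- `DATA_PATH`: directory of the dataset to run inference on. Images should be in jpg format.
- `NEED_PREPROCESS`: whether the attribute editing file needs preprocessing, either `y` or `n`. Choose `y` to preprocess. The attribute editing file must be preprocessed the first time inference is run, which takes some time when there are many images.
- `DEVICE_ID`: optional, defaults to 0.

[Note] The attribute editing file follows the format of `list_attr_celeba.txt` in the CelebA dataset. The first line is the number of images to run inference on, the second line lists the attributes to edit, and the remaining lines give each image name followed by its attribute tags. The number of images in the attribute editing file must equal the number of images in the dataset directory.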
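
To make the file layout concrete, the sketch below parses a tiny in-memory attribute editing file using only the standard library. The image names, the four-attribute header, and the helper `parse_attr_file` are hypothetical. The `(v + 1) // 2` mapping from -1/1 tags to 0/1 labels follows `preprocess.py`.

```python
SELECTED = ["Bald", "Bangs", "Eyeglasses"]  # shortened list for the example

ATTR_FILE = """\
2
Bald Bangs Eyeglasses Male
182001.jpg -1  1 -1  1
182002.jpg  1 -1  1 -1
"""


def parse_attr_file(text, selected):
    """Return {image name: [0/1 labels for the selected attributes]}."""
    lines = text.strip().splitlines()
    n_images = int(lines[0])          # line 1: number of images
    names = lines[1].split()          # line 2: attribute names
    columns = [names.index(a) for a in selected]
    rows = {}
    for line in lines[2:]:
        fields = line.split()
        values = [int(v) for v in fields[1:]]
        # CelebA stores -1/1; (v + 1) // 2 maps that to 0/1.
        rows[fields[0]] = [(values[c] + 1) // 2 for c in columns]
    if len(rows) != n_images:
        raise ValueError(f"header says {n_images} images, found {len(rows)}")
    return rows


print(parse_attr_file(ATTR_FILE, SELECTED))
```

Checking the row count against the first line catches mismatches early.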
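### Results

Inference results are saved in the directory where the script is run. The attribute-edited images are saved under `result_Files/`, and the inference timing statistics are saved under `time_Result/`. Edited images are named `imgName_attrId.jpg`. For example, `182001_1.jpg` is the result of editing the first attribute of the image named 182001. Whether an attribute is actually edited depends on the contents of the attribute editing file.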
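
A result file name can be split back into the image name and the attribute id. The sketch below does this with the standard library. `parse_result_name` and `attribute_name` are hypothetical helpers, and the 1-based lookup assumes the ids follow the order of the 13 attributes listed in this README (only "id 1 is the first attribute" is stated above).

```python
import os

ATTRS = [
    "Bald", "Bangs", "Black_Hair", "Blond_Hair", "Brown_Hair", "Bushy_Eyebrows",
    "Eyeglasses", "Male", "Mouth_Slightly_Open", "Mustache", "No_Beard",
    "Pale_Skin", "Young",
]


def parse_result_name(file_name):
    """Split `imgName_attrId.jpg` into (image name, attribute id)."""
    stem, _ext = os.path.splitext(os.path.basename(file_name))
    image_name, _sep, attr_id = stem.rpartition("_")
    return image_name, int(attr_id)


def attribute_name(attr_id):
    """Return the attribute for a 1-based id, or None when the id is outside 1..13."""
    if 1 <= attr_id <= len(ATTRS):
        return ATTRS[attr_id - 1]
    return None


image_name, attr_id = parse_result_name("result_Files/182001_1.jpg")
print(image_name, attr_id, attribute_name(attr_id))
```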
# Model Description

## Performance

### Evaluation Performance

#### AttGAN on CelebA

| Parameter                  | Ascend 910                                                  |
| -------------------------- | ----------------------------------------------------------- |
| Model version              | AttGAN                                                      |
| Resource                   | Ascend                                                      |
| Upload date                | 06/30/2021 (month/day/year)                                 |
| MindSpore version          | 1.2.0                                                       |
| Dataset                    | CelebA                                                      |
| Training parameters        | batch_size=32, lr=0.0002                                    |
| Optimizer                  | Adam                                                        |
| Generator output           | image                                                       |
| Speed                      | 5.56 step/s                                                 |
| Script                     | [AttGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/AttGAN) |
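
The speed in the table is given in steps per second. Combined with the batch size it can be turned into images per second and milliseconds per step. This is plain arithmetic on the two table values, not an additional measurement.

```python
steps_per_second = 5.56  # "Speed" row of the table above
batch_size = 32          # "Training parameters" row of the table above

images_per_second = steps_per_second * batch_size
ms_per_step = 1000.0 / steps_per_second

print(f"{images_per_second:.2f} images/s, {ms_per_step:.1f} ms/step")
# 177.92 images/s, 179.9 ms/step
```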
### Inference Performance

#### AttGAN on CelebA

| Parameter                  | Ascend 910                                                  |
| -------------------------- | ----------------------------------------------------------- |
| Model version              | AttGAN                                                      |
| Resource                   | Ascend                                                      |
| Upload date                | 06/30/2021 (month/day/year)                                 |
| MindSpore version          | 1.2.0                                                       |
| Dataset                    | CelebA                                                      |
| Inference parameters       | batch_size=1                                                |
| Output                     | image                                                       |

After inference completes, you obtain the attribute-edited versions of the original images.

# ModelZoo Homepage

Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)


@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make


@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <string_view>
#include <memory>
#include "include/api/types.h"
DIR *OpenDir(std::string_view dirName);
void Denorm(std::vector<mindspore::MSTensor> *outputs);
std::string RealPath(std::string_view path);
std::vector<std::string> GetAllFiles(std::string_view dirName);
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif


@ -0,0 +1,149 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using color_rep_type = std::underlying_type<mindspore::DataType>::type;
DEFINE_string(gen_mindir_path, "", "generator mindir path");
DEFINE_string(dataset_path, "", "dataset path");
DEFINE_string(attr_file_path, "", "attribute file path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_int32(image_height, 128, "image height");
DEFINE_int32(image_width, 128, "image width");
int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_gen_mindir_path).empty()) {
    std::cout << "Invalid generator mindir" << std::endl;
    return 1;
  }

  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  ascend310->SetBufferOptimizeMode("off_optimize");
  context->MutableDeviceInfo().push_back(ascend310);

  mindspore::Graph gen_graph;
  Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph);
  Model gen_model;
  Status gen_ret = gen_model.Build(GraphCell(gen_graph), context);
  if (gen_ret != kSuccess) {
    std::cout << "ERROR: Generator build failed." << std::endl;
    return 1;
  }
  size_t n_attrs = 13;
  auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs);
  auto all_files = GetAllFiles(FLAGS_dataset_path);
  std::map<double, double> costTime_map;
  double startTimeMs;
  double endTimeMs;
  size_t size = all_files.size();
  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    std::cout << "Start predict input files:" << all_files[i] << std::endl;
    auto img = std::make_shared<MSTensor>();
    // Preprocessing: decode -> resize -> normalize to [-1, 1] -> HWC to CHW.
    std::shared_ptr<TensorTransform> decode(new Decode());
    std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
    auto resizeShape = {FLAGS_image_height, FLAGS_image_width};
    std::shared_ptr<TensorTransform> resize(new Resize(resizeShape));
    std::shared_ptr<TensorTransform> normalize(
      new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
    Execute composeDecode({decode, resize, normalize, hwc2chw});
    auto image = ReadFileToTensor(all_files[i]);
    composeDecode(image, img.get());

    // One generator pass per attribute vector prepared for this image.
    for (size_t k = 0; k < n_attrs; ++k) {
      gettimeofday(&start, nullptr);
      std::vector<MSTensor> inputs;
      std::vector<MSTensor> outputs;
      size_t index = i * n_attrs + k;
      // Guard against an attribute file that lists fewer images than the dataset directory.
      if (index >= all_cfg.size()) {
        std::cout << "Attribute file has fewer entries than the dataset requires." << std::endl;
        return 1;
      }
      std::cout << static_cast<color_rep_type>(img->DataType()) << std::endl;
      inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize());
      inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(),
                          all_cfg[index].Data().get(), all_cfg[index].DataSize());
      Status gen_model_ret = gen_model.Predict(inputs, &outputs);
      if (gen_model_ret != kSuccess) {
        std::cout << "Generator inference " << all_files[i] << " failed." << std::endl;
        return 1;
      }
      Denorm(&outputs);
      // Strip the extension at the last '.', so a '.' in the directory part of the path is harmless.
      size_t pos = all_files[i].rfind('.');
      std::string fileName = all_files[i].substr(0, pos);
      WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs);
      gettimeofday(&end, nullptr);
      startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
      endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
      costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    }
  }
  double average = 0.0;
  int inferCount = 0;
  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  // Avoid dividing by zero when no inference was run (for example, an empty dataset directory).
  average = (inferCount > 0) ? average / inferCount : 0.0;

  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}


@ -0,0 +1,201 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <sstream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }

  // Collect regular files only, skipping "." and "..".
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  // Release the directory handle once the listing is complete.
  closedir(dir);

  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    // Keep only the file name and swap its extension for ".bin".
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    std::cout << fileName << std::endl;
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE * outputFile = fopen(outFileName.c_str(), "wb");
    if (outputFile == nullptr) {
      std::cout << "Can not open output file " << outFileName << std::endl;
      return 1;
    }
    fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "File path is empty" << std::endl;
    return mindspore::MSTensor();
  }

  // Image files are raw bytes, so open the stream in binary mode.
  std::ifstream ifs(file, std::ios::binary);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  // Size the tensor from the file length, then read the whole file into it.
  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
                             {static_cast<int64_t>(size)}, nullptr, size);

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}
// Split a line into whitespace-separated tokens.
std::vector<std::string> split(std::string inputs) {
  std::vector<std::string> line;
  std::stringstream stream(inputs);
  std::string result;
  while ( stream >> result ) {
    line.push_back(result);
  }
  return line;
}
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr) {
  std::vector<mindspore::MSTensor> res;
  if (file.empty()) {
    std::cout << "File path is empty." << std::endl;
    exit(1);
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist." << std::endl;
    exit(1);
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed." << std::endl;
    exit(1);
  }

  // The first two lines hold the image count and the attribute count.
  std::string n_images;
  std::string n_attrs;
  getline(ifs, n_images);
  getline(ifs, n_attrs);

  auto n_images_ = std::stoi(n_images);
  auto n_attrs_ = std::stoi(n_attrs);
  *n_ptr = n_attrs_;
  std::cout << "Image number is " << n_images << std::endl;
  std::cout << "Attribute number is " << n_attrs << std::endl;

  auto all_lines = n_images_ * n_attrs_;
  for (auto i = 0; i < all_lines; i++) {
    std::string val;
    getline(ifs, val);
    std::vector<std::string> val_split = split(val);
    if (val_split.size() < static_cast<size_t>(n_attrs_)) {
      std::cout << "File: " << file << " has a malformed attribute line." << std::endl;
      exit(1);
    }
    auto size = sizeof(float) * n_attrs_;
    mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32,
                               {static_cast<int64_t>(size)}, nullptr, size);
    // Write straight into the tensor buffer; a separate malloc'd scratch array would leak if never freed.
    float *elements = reinterpret_cast<float *>(buffer.MutableData());
    for (auto j = 0; j < n_attrs_; j++) elements[j] = atof(val_split[j].c_str());
    res.emplace_back(buffer);
  }
  ifs.close();

  return res;
}
DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  // Check lstat's return value so that `s` is never read uninitialized.
  if (realPath.empty() || lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}
std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  // A string_view is not guaranteed to be null-terminated, so copy it before calling realpath().
  std::string pathStr(path);
  char *realPathRet = realpath(pathStr.c_str(), realPathMem);
  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}
// Map generator outputs from [-1, 1] to [0, 255] in place, clamping out-of-range values.
void Denorm(std::vector<MSTensor> *outputs) {
  for (size_t i = 0; i < outputs->size(); ++i) {
    size_t outputSize = (*outputs)[i].DataSize();
    float* netOutput = reinterpret_cast<float *>((*outputs)[i].MutableData());
    size_t outputLen = outputSize / sizeof(float);

    for (size_t j = 0; j < outputLen; ++j) {
      netOutput[j] = (netOutput[j] + 1) / 2 * 255;
      netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j];
      netOutput[j] = (netOutput[j] > 255) ? 255 : netOutput[j];
    }
  }
}


@ -0,0 +1,147 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for testing AttGAN network"""
import argparse
import json
import math
import os
from os.path import join
import numpy as np
from PIL import Image
import mindspore.common.dtype as mstype
import mindspore.dataset as de
from mindspore import context, Tensor, ops
from mindspore.train.serialization import load_param_into_net
from src.attgan import Gen
from src.cell import init_weights
from src.data import check_attribute_conflict
from src.data import get_loader, Custom
from src.helpers import Progressbar
from src.utils import resume_generator, denorm
# Fall back to device 0 when DEVICE_ID is not exported.
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)


def parse(arg=None):
    """Define configuration of Evaluation"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', dest='experiment_name', required=True)
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--num_test', dest='num_test', type=int)
    parser.add_argument('--gen_ckpt_name', type=str, default='')
    parser.add_argument('--custom_img', action='store_true')
    parser.add_argument('--custom_data', type=str, default='../data/custom')
    parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
    parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
    return parser.parse_args(arg)


args_ = parse()
print(args_)

# Load the training-time settings, then override them with the command-line options.
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.gen_ckpt_name = args_.gen_ckpt_name
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr
args.shortcut_layers = args_.shortcut_layers
args.inject_layers = args_.inject_layers
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)
print(args)

# Data loader
if args.custom_img:
    output_path = join("output", args.experiment_name, "custom_testing")
    os.makedirs(output_path, exist_ok=True)
    test_dataset = Custom(args.custom_data, args.custom_attr, args.attrs)
    test_len = len(test_dataset)
else:
    output_path = join("output", args.experiment_name, "sample_testing")
    os.makedirs(output_path, exist_ok=True)
    test_dataset = get_loader(args.data_path, args.attr_path,
                              selected_attrs=args.attrs,
                              mode="test"
                              )
    test_len = len(test_dataset)
dataset_column_names = ["image", "attr"]
num_parallel_workers = 8
ds = de.GeneratorDataset(test_dataset, column_names=dataset_column_names,
                         num_parallel_workers=min(32, num_parallel_workers))
ds = ds.batch(1, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=False)
test_dataset_iter = ds.create_dict_iterator()

if args.num_test is None:
    print('Testing images:', test_len)
else:
    print('Testing images:', min(test_len, args.num_test))

# Model loader
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
          args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test')

# Initialize network
init_weights(gen, 'KaimingUniform', math.sqrt(5))
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
load_param_into_net(gen, para_gen)

print("Network initializes successfully.")

progressbar = Progressbar()
it = 0
for data in test_dataset_iter:
    img_a = data["image"]
    att_a = data["attr"]
    if args.num_test is not None and it == args.num_test:
        break
    att_a = Tensor(att_a, mstype.float32)

    # Build one target attribute vector per attribute, each with that attribute flipped.
    att_b_list = [att_a]
    for i in range(args.n_attrs):
        clone = ops.Identity()
        tmp = clone(att_a)
        tmp[:, i] = 1 - tmp[:, i]
        tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
        att_b_list.append(tmp)

    # Entry 0 keeps the original attributes; entry i > 0 edits attribute i - 1 at intensity test_int.
    samples = [img_a]
    for i, att_b in enumerate(att_b_list):
        att_b_ = (att_b * 2 - 1) * args.thres_int
        if i > 0:
            att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
        samples.append(gen(img_a, att_b_, mode="enc-dec"))

    # Concatenate all outputs along the width into a single strip image.
    cat = ops.Concat(axis=3)
    samples = cat(samples).asnumpy()
    result = denorm(samples)
    result = np.reshape(result, (128, -1, 3))  # assumes 128-pixel-high outputs
    im = Image.fromarray(np.uint8(result))
    if args.custom_img:
        out_file = test_dataset.images[it]
    else:
        out_file = "{:06d}.jpg".format(it + 182638)
    im.save(output_path + '/' + out_file)
    print('Successfully saved image to ' + output_path + '/' + out_file)
    it += 1


@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import argparse
import json
from os.path import join
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net
from src.utils import resume_generator
from src.attgan import Gen
parser = argparse.ArgumentParser(description='Attribute Edit')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
args_ = parser.parse_args()
print(args_)
# Load the training-time settings, then override them with the command-line options.
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.device_id = args_.device_id
args.batch_size = args_.batch_size
args.gen_ckpt_name = args_.gen_ckpt_name
args.file_format = args_.file_format

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

if __name__ == '__main__':
    gen = Gen(mode="test")

    para_gen = resume_generator(args, gen, args.gen_ckpt_name)
    load_param_into_net(gen, para_gen)

    # Dummy inputs that fix the exported shapes: one 3x128x128 image and 13 attribute values.
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
    input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))
    gen_file = "attgan_mindir"
    export(gen, *(input_array, input_label), file_name=gen_file, file_format=args.file_format)


@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from PIL import Image


def parse(arg=None):
    """Define configuration of postprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--bin_path', type=str, default='./result_Files/')
    parser.add_argument('--target_path', type=str, default='./result_Files/')
    return parser.parse_args(arg)


def load_bin_file(bin_file, shape=None, dtype="float32"):
    """Load data from bin file"""
    data = np.fromfile(bin_file, dtype=dtype)
    if shape:
        data = np.reshape(data, shape)
    return data


def save_bin_to_image(data, out_name):
    """Save bin file to image arrays"""
    # CHW -> HWC, the layout PIL expects.
    image = np.transpose(data, (1, 2, 0))
    im = Image.fromarray(np.uint8(image))
    im.save(out_name)
    print("Successfully saved image to " + out_name)


def scan_dir(bin_path):
    """Scan directory"""
    out = os.listdir(bin_path)
    return out


def postprocess(bin_path):
    """Post process bin file"""
    file_list = scan_dir(bin_path)
    for file in file_list:
        # Skip anything that is not a raw output, e.g. images written by an earlier run.
        if not file.endswith(".bin"):
            continue
        # os.path.join works whether or not bin_path ends with a separator.
        data = load_bin_file(os.path.join(bin_path, file), shape=(3, 128, 128), dtype="float32")
        file_name = os.path.splitext(file)[0] + ".jpg"
        outfile = os.path.join(args.target_path, file_name)
        save_bin_to_image(data, outfile)


if __name__ == "__main__":
    args = parse()
    postprocess(args.bin_path)


@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pre process for 310 inference"""
import os
from os.path import join
import argparse
import numpy as np
selected_attrs = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]


def parse(arg=None):
    """Define configuration of preprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
    parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    return parser.parse_args(arg)


args = parse()
args.n_attrs = len(args.attrs)
def check_attribute_conflict(att_batch, att_name, att_names):
"""Check Attributes"""
def _set(att, att_name):
if att_name in att_names:
att[att_names.index(att_name)] = 0.0
att_id = att_names.index(att_name)
for att in att_batch:
if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
_set(att, 'Bangs')
elif att_name == 'Bangs' and att[att_id] != 0:
_set(att, 'Bald')
_set(att, 'Receding_Hairline')
elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
for n in ['Straight_Hair', 'Wavy_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
for n in ['Mustache', 'No_Beard']:
if n != att_name:
_set(att, n)
return att_batch
def read_cfg_file(attr_path):
    """Read configuration from attribute file"""
    # Line 1 holds the number of images, line 2 the attribute names.
    with open(attr_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    attr_number = int(lines[0])
    attr_list = lines[1].split()
    # Column 0 is the image file name, so attribute columns start at 1.
    atts = [attr_list.index(att) + 1 for att in selected_attrs]
    # np.int was removed in NumPy 1.24; the builtin int is equivalent.
    labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
    labels = [labels] if attr_number == 1 else labels[0:]
    new_attr = []
    for index in range(attr_number):
        # Map labels from {-1, 1} to {0, 1}.
        att = [np.asarray((labels[index] + 1) // 2)]
        new_attr.append(att)
    new_attr = np.array(new_attr)
    return new_attr, attr_number
def preprocess_cfg(attrs, numbers):
"""Preprocess attribute file"""
new_attr = []
for index in range(numbers):
attr = attrs[index]
att_b_list = [attr]
for i in range(args.n_attrs):
tmp = attr.copy()
tmp[:, i] = 1 - tmp[:, i]
tmp = check_attribute_conflict(tmp, selected_attrs[i], selected_attrs)
att_b_list.append(tmp)
for i, att_b in enumerate(att_b_list):
att_b_ = (att_b * 2 - 1) * args.thres_int
if i > 0:
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
new_attr.append(att_b_)
return new_attr
def write_cfg_file(attrs, numbers):
"""Write attribute file"""
cur_dir = os.getcwd()
print(cur_dir)
path = join(cur_dir, 'attrs.txt')
with open(path, "w") as f:
f.writelines(str(numbers))
f.writelines("\n")
f.writelines(str(args.n_attrs))
f.writelines("\n")
counts = numbers * args.n_attrs
for index in range(counts):
attrs_list = attrs[index][0]
new_attrs_list = ["%s" % x for x in attrs_list]
sequence = " ".join(new_attrs_list)
f.writelines(sequence)
f.writelines("\n")
print("Generate cfg file successfully.")
if __name__ == "__main__":
if args.attrs_path is None:
raise ValueError("--attrs_path must point to an attribute file")
attributes, n_images = read_cfg_file(args.attrs_path)
new_attrs = preprocess_cfg(attributes, n_images)
write_cfg_file(new_attrs, n_images)
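Review note: `preprocess_cfg` expands each image's attribute vector into the original plus one variant per attribute with that attribute flipped, then rescales the values. The following is a minimal pure-Python sketch of that expansion for a single image. The helper name `make_target_attrs` is hypothetical, and the sketch deliberately leaves out the `check_attribute_conflict` step.

```python
def make_target_attrs(att_a, thres_int=0.5, test_int=1.0):
    """Return [original, flip attr 0, flip attr 1, ...] with rescaled values.

    Hypothetical helper for illustration; conflict resolution between
    attributes (check_attribute_conflict) is intentionally omitted.
    """
    variants = [list(att_a)]
    for i in range(len(att_a)):
        flipped = list(att_a)
        flipped[i] = 1 - flipped[i]  # toggle attribute i
        variants.append(flipped)

    scaled = []
    for i, att_b in enumerate(variants):
        # Map {0, 1} to {-thres_int, +thres_int}.
        att_b_ = [(v * 2 - 1) * thres_int for v in att_b]
        if i > 0:
            # The edited attribute is rescaled to +/- test_int.
            att_b_[i - 1] = att_b_[i - 1] * test_int / thres_int
        scaled.append(att_b_)
    return scaled


if __name__ == "__main__":
    for row in make_target_attrs([1, 0, 1]):
        print(row)
```

With the defaults, only the flipped attribute in each variant reaches magnitude 1.0, while the untouched attributes stay at magnitude 0.5.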

@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [ATTR_PATH]"
exit 1
fi
export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
echo "HCCL connect timeout has been changed to 600 seconds"
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
data_path=$2
attr_path=$3
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cd ./LOG$i || exit
echo "Start training for rank $i, device $DEVICE_ID"
env > env.log
cd ../../
taskset -c $cmdopt python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 128_shortcut1_inject1_none \
--data_path $data_path \
--attr_path $attr_path \
--run_distribute 1 > ./scripts/LOG$i/log.txt 2>&1 &
cd scripts
done
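Review note: the loop above pins each rank to a contiguous block of logical cores through `taskset -c start-end`. The sketch below reproduces that arithmetic in Python, assuming a hypothetical helper `core_ranges` that is not part of the repository. As in the shell script, integer division means any leftover cores stay unassigned.

```python
def core_ranges(total_cores, rank_size):
    """Return the 'start-end' core range string for every rank.

    Mirrors avg_core_per_rank = cores / RANK_SIZE and
    end = start + avg_core_per_rank - 1 from the launch script.
    """
    per_rank = total_cores // rank_size
    ranges = []
    for rank in range(rank_size):
        start = rank * per_rank
        end = start + per_rank - 1
        ranges.append(f"{start}-{end}")
    return ranges


if __name__ == "__main__":
    # Example: a 96-core host shared by 8 ranks.
    print(core_ranges(96, 8))
```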

@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]"
exit 1
fi
experiment_name=$1
data_path=$2
attr_path=$3
gen_ckpt_name=$4
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "The number of logical core" $cores
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf EVAL_LOG
mkdir ./EVAL_LOG
cd ./EVAL_LOG || exit
echo "Start evaluation for rank 0, device 0, directory is EVAL_LOG"
env > env.log
cd ../../
python eval.py \
--experiment_name $experiment_name \
--test_int 1.0 \
--custom_data $data_path \
--custom_attr $attr_path \
--custom_img \
--gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 &

@ -0,0 +1,128 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
NEED_PREPROCESS means whether preprocessing is needed; its value is 'y' or 'n'.
DEVICE_ID is optional; if it is omitted, device 0 is used"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
gen_model=$(get_real_path $1)
attr_path=$(get_real_path $2)
data_path=$(get_real_path $3)
if [ "$4" == "y" ] || [ "$4" == "n" ];then
need_preprocess=$4
else
echo "NEED_PREPROCESS says whether preprocessing is needed; its value must be in [y, n]"
exit 1
fi
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
echo "generator mindir name: "$gen_model
echo "attribute file path: "$attr_path
echo "dataset path: "$data_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
    echo "Start to preprocess attr file..."
    # Return non-zero on failure so the caller's $? check can detect it.
    python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log || return 1
    echo "Attribute file generated successfully!"
}
function compile_app()
{
    echo "Start to compile source code..."
    cd ../ascend310_infer || exit
    # Return non-zero on failure so the caller's $? check can detect it.
    bash build.sh &> build.log || return 1
    echo "Compiled successfully."
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
echo "Start to execute inference..."
../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log
}
function postprocess_data()
{
echo "Start to postprocess image file..."
python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
}
if [ $need_preprocess == "y" ]; then
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess attrs failed"
exit 1
fi
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
postprocess_data
if [ $? -ne 0 ]; then
echo "postprocess images failed"
exit 1
fi

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_single_train.sh [EXPERIMENT_NAME] [DATA_PATH] [ATTR_PATH]"
exit 1
fi
experiment_name=$1
data_path=$2
attr_path=$3
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf LOG
mkdir ./LOG
cd ./LOG || exit
echo "Start training for rank 0, device 0, directory is LOG"
env > env.log
cd ../../
python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name $experiment_name \
--data_path $data_path \
--attr_path $attr_path > ./scripts/LOG/log.txt 2>&1 &

@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AttGAN Network Topology"""
import mindspore.ops.operations as P
from mindspore import nn
from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
# Image size 128 x 128
MAX_DIM = 64 * 16
class Gen(nn.Cell):
"""Generator"""
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu",
n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"):
super().__init__()
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
self.inject_layers = min(inject_layers, dec_layers - 1)
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
layers = []
n_in = 3
for i in range(enc_layers):
n_out = min(enc_dim * 2 ** i, MAX_DIM)
layers += [Conv2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=enc_norm_fn, acti_fn=enc_acti_fn, mode=mode
)]
n_in = n_out
self.enc_layers = nn.CellList(layers)
layers = []
n_in = n_in + n_attrs # 1024 + 13
for i in range(dec_layers):
if i < dec_layers - 1:
n_out = min(dec_dim * 2 ** (dec_layers - i - 1), MAX_DIM)
layers += [ConvTranspose2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=dec_norm_fn, acti_fn=dec_acti_fn, mode=mode
)]
n_in = n_out
n_in = n_in + n_in // 2 if self.shortcut_layers > i else n_in
n_in = n_in + n_attrs if self.inject_layers > i else n_in
else:
layers += [ConvTranspose2dBlock(
n_in, 3, (4, 4), stride=2, padding=1, norm_fn='none', acti_fn='tanh', mode=mode
)]
self.dec_layers = nn.CellList(layers)
self.view = P.Reshape()
self.repeat = P.Tile()
self.cat = P.Concat(1)
def encoder(self, x):
"""Encoder construct"""
z = x
zs = []
for layer in self.enc_layers:
z = layer(z)
zs.append(z)
return zs
def decoder(self, zs, a):
"""Decoder construct"""
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
multiples = (1, 1, self.f_size, self.f_size)
a_tile = self.repeat(a_tile, multiples)
z = self.cat((zs[-1], a_tile))
i = 0
for layer in self.dec_layers:
z = layer(z)
if self.shortcut_layers > i:
z = self.cat((z, zs[len(self.dec_layers) - 2 - i]))
if self.inject_layers > i:
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
multiples = (1, 1, self.f_size * 2 ** (i + 1), self.f_size * 2 ** (i + 1))
a_tile = self.repeat(a_tile, multiples)
z = self.cat((z, a_tile))
i = i + 1
return z
def construct(self, x, a=None, mode="enc-dec"):
result = None
if mode == "enc-dec":
out = self.encoder(x)
result = self.decoder(out, a)
if mode == "enc":
result = self.encoder(x)
if mode == "dec":
result = self.decoder(x, a)
return result
class Dis(nn.Cell):
"""Discriminator"""
def __init__(self, dim=64, norm_fn='none', acti_fn='lrelu',
fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128, mode='test'):
super().__init__()
self.f_size = img_size // 2 ** n_layers
layers = []
n_in = 3
for i in range(n_layers):
n_out = min(dim * 2 ** i, MAX_DIM)
layers += [Conv2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=norm_fn, acti_fn=acti_fn, mode=mode
)]
n_in = n_out
self.conv = nn.SequentialCell(layers)
self.fc_adv = nn.SequentialCell(
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
LinearBlock(fc_dim, 1, 'none', 'none', mode)])
self.fc_cls = nn.SequentialCell(
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
LinearBlock(fc_dim, 13, 'none', 'none', mode)])
def construct(self, x):
"""construct"""
h = self.conv(x)
view = P.Reshape()
h = view(h, (h.shape[0], -1))
return self.fc_adv(h), self.fc_cls(h)
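Review note: the channel bookkeeping in `Gen.__init__` is easy to misread, because the shortcut and attribute-injection concatenations change the input width of the following decoder layer. The sketch below replays only that arithmetic in plain Python, without MindSpore. The function name `generator_channels` is hypothetical, and its defaults copy those of `Gen`.

```python
MAX_DIM = 64 * 16


def generator_channels(enc_dim=64, enc_layers=5, dec_dim=64, dec_layers=5,
                       n_attrs=13, shortcut_layers=1, inject_layers=1):
    """Return (in, out) channel pairs for every encoder and decoder layer."""
    shortcut_layers = min(shortcut_layers, dec_layers - 1)
    inject_layers = min(inject_layers, dec_layers - 1)

    enc, n_in = [], 3
    for i in range(enc_layers):
        n_out = min(enc_dim * 2 ** i, MAX_DIM)
        enc.append((n_in, n_out))
        n_in = n_out

    dec = []
    n_in = n_in + n_attrs  # latent features concatenated with the tiled attributes
    for i in range(dec_layers):
        if i < dec_layers - 1:
            n_out = min(dec_dim * 2 ** (dec_layers - i - 1), MAX_DIM)
            dec.append((n_in, n_out))
            n_in = n_out
            if shortcut_layers > i:
                n_in += n_in // 2  # encoder feature map joined via the shortcut
            if inject_layers > i:
                n_in += n_attrs    # attributes injected again at this scale
        else:
            dec.append((n_in, 3))  # final layer produces an RGB image
    return enc, dec


if __name__ == "__main__":
    enc, dec = generator_channels()
    print("encoder:", enc)
    print("decoder:", dec)
```

With the default arguments, the second decoder layer receives 1024 + 512 + 13 = 1549 input channels, which is why it is wider than the first.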

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network Component"""
from mindspore import nn
def add_normalization_1d(layers, fn, n_out, mode='test'):
if fn == "none":
pass
elif fn == "batchnorm":
layers.append(nn.BatchNorm1d(n_out, use_batch_statistics=(mode == 'train')))
elif fn == "instancenorm":
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
else:
raise Exception('Unsupported normalization: ' + str(fn))
return layers
def add_normalization_2d(layers, fn, n_out, mode='test'):
if fn == 'none':
pass
elif fn == 'batchnorm':
layers.append(nn.BatchNorm2d(n_out, use_batch_statistics=(mode == 'train')))
elif fn == "instancenorm":
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
else:
raise Exception('Unsupported normalization: ' + str(fn))
return layers
def add_activation(layers, fn):
"""Add Activation"""
if fn == "none":
pass
elif fn == "relu":
layers.append(nn.ReLU())
elif fn == "lrelu":
layers.append(nn.LeakyReLU(alpha=0.01))
elif fn == "sigmoid":
layers.append(nn.Sigmoid())
elif fn == "tanh":
layers.append(nn.Tanh())
else:
raise Exception('Unsupported activation function: ' + str(fn))
return layers
class LinearBlock(nn.Cell):
"""Linear Block"""
def __init__(self, n_in, n_out, norm_fn="none", acti_fn="none", mode='test'):
super().__init__()
layers = [nn.Dense(n_in, n_out, has_bias=(norm_fn == 'none'))]
layers = add_normalization_1d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)
class Conv2dBlock(nn.Cell):
"""Convolution Block"""
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
norm_fn=None, acti_fn=None, mode='test'):
super().__init__()
layers = [nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
has_bias=(norm_fn == 'none'))]
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)
class ConvTranspose2dBlock(nn.Cell):
"""Transpose Convolution Block"""
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
norm_fn=None, acti_fn=None, mode='test'):
super().__init__()
layers = [nn.Conv2dTranspose(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
has_bias=(norm_fn == 'none'))]
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)

@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell Definition"""
import numpy as np
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import nn, ops
from mindspore.common import initializer as init, set_seed
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
set_seed(1)
np.random.seed(1)
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
init_type (str): The name of an initialization method: normal | xavier.
init_gain (float): Gain factor for normal and xavier.
"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'KaimingUniform':
cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, (nn.GroupNorm, nn.BatchNorm2d)):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class GenWithLossCell(nn.Cell):
"""
Wrap the network with loss function to return generator loss
"""
def __init__(self, network):
super().__init__(auto_prefix=False)
self.network = network
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
_, g_loss, _, _, _, = self.network(img_a, att_a, att_a_, att_b, att_b_)
return g_loss
class DisWithLossCell(nn.Cell):
"""
Wrap the network with loss function to return discriminator loss
"""
def __init__(self, network):
super().__init__(auto_prefix=False)
self.network = network
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
d_loss, _, _, _, _ = self.network(img_a, att_a, att_a_, att_b, att_b_)
return d_loss
class TrainOneStepCellGen(nn.Cell):
"""Encapsulation class of AttGAN generator network training."""
def __init__(self, generator, optimizer, sens=1.0):
super().__init__()
self.optimizer = optimizer
self.generator = generator
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.weights = optimizer.parameters
self.network = GenWithLossCell(generator)
self.network.add_flags(defer_inline=True)
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
weights = self.weights
_, loss, gf_loss, gc_loss, gr_loss = self.generator(img_a, att_a, att_a_, att_b, att_b_)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
class TrainOneStepCellDis(nn.Cell):
"""Encapsulation class of AttGAN discriminator network training."""
def __init__(self, discriminator, optimizer, sens=1.0):
super().__init__()
self.optimizer = optimizer
self.discriminator = discriminator
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.weights = optimizer.parameters
self.network = DisWithLossCell(discriminator)
self.network.add_flags(defer_inline=True)
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
weights = self.weights
loss, d_real_loss, d_fake_loss, dc_loss, df_gp = self.discriminator(img_a, att_a, att_a_, att_b, att_b_)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp

@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" DataLoader: CelebA"""
import argparse
import os
import numpy as np
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import context
from mindspore.dataset.transforms import py_transforms
from src.utils import DistributedSampler
class Custom:
"""
Custom data loader
"""
def __init__(self, data_path, attr_path, selected_attrs):
self.data_path = data_path
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
atts = [att_list.index(att) + 1 for att in selected_attrs]
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
transform = [py_vision.ToPIL()]
transform.append(py_vision.Resize([128, 128]))
transform.append(py_vision.ToTensor())
transform.append(py_vision.Normalize(mean=mean, std=std))
transform = py_transforms.Compose(transform)
self.transform = transform
self.images = np.array([images]) if images.size == 1 else images[0:]
self.labels = np.array([labels]) if images.size == 1 else labels[0:]
self.length = len(self.images)
def __getitem__(self, index):
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
att = np.asarray((self.labels[index] + 1) // 2)
image = np.squeeze(self.transform(image))
return image, att
def __len__(self):
return self.length
class CelebA:
"""
CelebA dataset
Input:
data_path: Image Path
attr_path: Attr_list Path
image_size: Image Size
mode: Train or Test
selected_attrs: selected attributes
transform: Image Processing
"""
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs, transform, split_point=182000):
super().__init__()
self.data_path = data_path
self.transform = transform
self.img_size = image_size
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
atts = [att_list.index(att) + 1 for att in selected_attrs]
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
if mode == "train":
self.images = images[:split_point]
self.labels = labels[:split_point]
if mode == "test":
self.images = images[split_point:]
self.labels = labels[split_point:]
self.length = len(self.images)
def __getitem__(self, index):
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
att = np.asarray((self.labels[index] + 1) // 2)
image = np.squeeze(self.transform(image))
return image, att
def __len__(self):
return self.length
def get_loader(data_root, attr_path, selected_attrs, crop_size=170, image_size=128, mode="train", split_point=182000):
"""Build and return dataloader"""
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
transform = [py_vision.ToPIL()]
transform.append(py_vision.CenterCrop((crop_size, crop_size)))
transform.append(py_vision.Resize([image_size, image_size]))
transform.append(py_vision.ToTensor())
transform.append(py_vision.Normalize(mean=mean, std=std))
transform = py_transforms.Compose(transform)
dataset = CelebA(data_root, attr_path, image_size, mode, selected_attrs, transform, split_point=split_point)
return dataset
def data_loader(img_path, attr_path, selected_attrs, mode="train", batch_size=1, device_num=1, rank=0, shuffle=True,
split_point=182000):
"""CelebA data loader"""
num_parallel_workers = 8
dataset_loader = get_loader(img_path, attr_path, selected_attrs, mode=mode, split_point=split_point)
length_dataset = len(dataset_loader)
distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
dataset_column_names = ["image", "attr"]
if device_num != 8:
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
num_parallel_workers=min(32, num_parallel_workers),
sampler=distributed_sampler)
ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
else:
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
# ds = ds.repeat(200)
return ds, length_dataset
def check_attribute_conflict(att_batch, att_name, att_names):
"""Check Attributes"""
def _set(att, att_name):
if att_name in att_names:
att[att_names.index(att_name)] = 0.0
att_id = att_names.index(att_name)
for att in att_batch:
if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
_set(att, 'Bangs')
elif att_name == 'Bangs' and att[att_id] != 0:
_set(att, 'Bald')
_set(att, 'Receding_Hairline')
elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
for n in ['Straight_Hair', 'Wavy_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
for n in ['Mustache', 'No_Beard']:
if n != att_name:
_set(att, n)
return att_batch
if __name__ == "__main__":
attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]
parser = argparse.ArgumentParser()
parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to test')
parser.add_argument('--data_path', dest='data_path', type=str, required=True)
parser.add_argument('--attr_path', dest='attr_path', type=str, required=True)
args = parser.parse_args()
####### Test CelebA #######
context.set_context(device_target="Ascend")
dataset_ce, length_ce = data_loader(args.data_path, args.attr_path, attrs_default, mode="train")
i = 0
for data in dataset_ce.create_dict_iterator():
print('Number:', i, 'Value:', data["attr"], 'Type:', type(data["attr"]))
i += 1
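Review note: two small value mappings recur in this loader. `(label + 1) // 2` turns CelebA's {-1, 1} annotations into {0, 1}, and `Normalize(mean=0.5, std=0.5)` maps `ToTensor` pixel values from [0, 1] to [-1, 1], which matches the `tanh` output range of the generator. The sketch below shows both in plain Python; the helper names `to_binary` and `normalize` are hypothetical.

```python
def to_binary(labels):
    """Map CelebA attribute labels from {-1, 1} to {0, 1}."""
    return [(v + 1) // 2 for v in labels]


def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel normalization applied to a pixel value in [0, 1]."""
    return (pixel - mean) / std


if __name__ == "__main__":
    print(to_binary([-1, 1, 1, -1]))
    print([normalize(p) for p in (0.0, 0.5, 1.0)])
```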

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for training"""
import datetime
import platform
from tqdm import tqdm
def name_experiment(prefix="", suffix=""):
    """Build an experiment name from a timestamp, the host name, and an optional prefix and suffix."""
    experiment_name = datetime.datetime.now().strftime('%b%d_%H-%M-%S_') + platform.node()
    if prefix:
        experiment_name = prefix + '_' + experiment_name
    if suffix:
        experiment_name = experiment_name + '_' + suffix
    return experiment_name


class Progressbar():
    """Progress Bar"""
    def __init__(self):
        self.p = None

    def __call__(self, iterable, length):
        self.p = tqdm(iterable, total=length)
        return self.p

    def say(self, **kwargs):
        if self.p is not None:
            self.p.set_postfix(**kwargs)


@ -0,0 +1,164 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss Computation of Generator and Discriminator"""
import numpy as np
import mindspore
import mindspore.ops.operations as P
from mindspore import dtype as mstype
from mindspore import nn, Tensor, ops
from mindspore.ops import constexpr
class ClassificationLoss(nn.Cell):
    """Define classification loss for AttGAN"""
    def __init__(self):
        super().__init__()
        self.bce_loss = P.BinaryCrossEntropy(reduction='sum')

    def construct(self, pred, label):
        weight = ops.Ones()(pred.shape, mindspore.float32)
        pred_ = P.Sigmoid()(pred)
        x = self.bce_loss(pred_, label, weight) / pred.shape[0]
        return x
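@constexpr
def generate_tensor(batch_size):
    # WGAN-GP interpolates between real and fake samples with a coefficient
    # drawn uniformly from [0, 1), so use rand (uniform), not randn (Gaussian).
    np_array = np.random.rand(batch_size, 1, 1, 1)
    return Tensor(np_array, mindspore.float32)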
class GradientWithInput(nn.Cell):
    """Get Discriminator Gradient with Input"""
    def __init__(self, discriminator):
        super().__init__()
        self.reduce_sum = ops.ReduceSum()
        self.discriminator = discriminator
        self.discriminator.set_train(mode=True)

    def construct(self, interpolates):
        decision_interpolate, _ = self.discriminator(interpolates)
        decision_interpolate = self.reduce_sum(decision_interpolate, 0)
        return decision_interpolate
class WGANGPGradientPenalty(nn.Cell):
    """Define the WGAN-GP gradient penalty for AttGAN"""
    def __init__(self, discriminator):
        super().__init__()
        self.gradient_op = ops.GradOperation()
        self.reduce_sum = ops.ReduceSum()
        self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
        self.sqrt = ops.Sqrt()
        self.discriminator = discriminator
        self.GradientWithInput = GradientWithInput(discriminator)

    def construct(self, x_real, x_fake):
        """get gradient penalty"""
        batch_size = x_real.shape[0]
        alpha = generate_tensor(batch_size)
        alpha = alpha.expand_as(x_real)
        x_fake = ops.functional.stop_gradient(x_fake)
        # Points on the line between each real sample and its fake counterpart
        x_hat = x_real + alpha * (x_fake - x_real)
        gradient = self.gradient_op(self.GradientWithInput)(x_hat)
        # Per-sample L2 norm of the gradient, penalized for deviating from 1
        gradient_1 = ops.reshape(gradient, (batch_size, -1))
        gradient_1 = self.sqrt(self.reduce_sum(gradient_1 * gradient_1, 1))
        gradient_penalty = self.reduce_sum((gradient_1 - 1.0) ** 2) / x_real.shape[0]
        return gradient_penalty
class GenLoss(nn.Cell):
    """Define total Generator loss"""
    def __init__(self, args, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
        self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
        self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
        self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)
        self.cyc_loss = P.ReduceMean()
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.rec_loss = nn.L1Loss("mean")

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        """Get generator loss"""
        # generate
        zs_a = self.generator(img_a, mode="enc")
        img_fake = self.generator(zs_a, att_b_, mode="dec")
        img_recon = self.generator(zs_a, att_a_, mode="dec")
        # discriminate
        d_fake, dc_fake = self.discriminator(img_fake)
        # generator loss
        gf_loss = - self.cyc_loss(d_fake)
        gc_loss = self.cls_loss(dc_fake, att_b)
        gr_loss = self.rec_loss(img_a, img_recon)
        g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss
        return (img_fake, g_loss, gf_loss, gc_loss, gr_loss)
class DisLoss(nn.Cell):
    """Define total discriminator loss"""
    def __init__(self, args, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.cyc_loss = P.ReduceMean()
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.WGANLoss = WGANGPGradientPenalty(discriminator)
        self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
        self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
        self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
        self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)

    def construct(self, img_a, att_a, att_a_, att_b, att_b_):
        """Get discriminator loss"""
        # generate
        img_fake = self.generator(img_a, att_b_, mode="enc-dec")
        # discriminate
        d_real, dc_real = self.discriminator(img_a)
        d_fake, _ = self.discriminator(img_fake)
        # discriminator losses
        d_real_loss = - self.cyc_loss(d_real)
        d_fake_loss = self.cyc_loss(d_fake)
        df_loss = d_real_loss + d_fake_loss
        df_gp = self.WGANLoss(img_a, img_fake)
        dc_loss = self.cls_loss(dc_real, att_a)
        d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss
        return (d_loss, d_real_loss, d_fake_loss, dc_loss, df_gp)
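The gradient penalty used by the discriminator loss can be checked by hand on a toy critic. The sketch below is not the MindSpore implementation: it assumes a hypothetical quadratic critic D(x) = 0.5 * ||x||^2, whose gradient at any point is the point itself, so no automatic differentiation is needed. The function name `toy_gradient_penalty` and the sample values are illustrative only.

```python
import math
import random


def toy_gradient_penalty(x_real, x_fake, rng):
    """Mean of (||grad D(x_hat)|| - 1)^2 for the toy critic D(x) = 0.5 * ||x||^2."""
    penalties = []
    for real, fake in zip(x_real, x_fake):
        alpha = rng.random()  # one uniform coefficient in [0, 1) per sample
        # Interpolate between the real sample and its fake counterpart
        x_hat = [r + alpha * (f - r) for r, f in zip(real, fake)]
        grad = x_hat  # gradient of 0.5 * ||x||^2 evaluated at x_hat
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


rng = random.Random(0)
# Identical real and fake samples make x_hat independent of alpha,
# so the result can be verified by hand: ||(3, 4)|| = 5, penalty (5 - 1)^2.
print(toy_gradient_penalty([[3.0, 4.0]], [[3.0, 4.0]], rng))
```

The penalty is zero only when the gradient norm at the interpolated point equals 1, which is the soft Lipschitz constraint that `lambda_gp` weights in the discriminator loss.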


@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions"""
import math
import os
import numpy as np
from mindspore import load_checkpoint
class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.epoch = 0
        self.rank = rank
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))
        # pad by wrapping around so every replica gets the same number of samples
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank: self.total_size: self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
def resume_generator(args, generator, gen_ckpt_name):
    """Restore the trained generator"""
    generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name)
    print("Loading the trained generator from {}...".format(generator_path))
    param_generator = load_checkpoint(generator_path, generator)
    return param_generator


def resume_discriminator(args, discriminator, dis_ckpt_name):
    """Restore the trained discriminator"""
    discriminator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dis_ckpt_name)
    print("Loading the trained discriminator from {}...".format(discriminator_path))
    param_discriminator = load_checkpoint(discriminator_path, discriminator)
    return param_discriminator


def denorm(x):
    """Convert an NCHW batch in [-1, 1] to NHWC pixel values in [0, 255]."""
    image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    return image_numpy
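The padding and strided slicing in `DistributedSampler.__iter__` can be traced with plain lists. The following sketch covers the non-shuffled path only; `shard_indices` is a hypothetical helper written for illustration, not part of the file above.

```python
import math


def shard_indices(dataset_size, num_replicas, rank):
    """Return the sample indices one rank would see when shuffle is off."""
    num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_size))
    # Pad by wrapping around so the list divides evenly across replicas
    indices += indices[:total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its own rank
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, 4, rank))
```

With 10 samples on 4 replicas, each rank receives 3 indices, and the two padded slots reuse indices 0 and 1, so ranks 2 and 3 see one sample that another rank also sees.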


@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for training AttGAN network"""
import argparse
import datetime
import json
import math
import os
from os.path import join
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from mindspore.train.serialization import load_param_into_net
from src.attgan import Gen, Dis
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
from src.data import data_loader
from src.helpers import Progressbar
from src.loss import GenLoss, DisLoss
from src.utils import resume_generator, resume_discriminator
attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]
def parse(arg=None):
    """Define configuration of Model"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to learn')
    parser.add_argument('--data', dest='data', type=str, choices=['CelebA'], default='CelebA')
    parser.add_argument('--data_path', dest='data_path', type=str, default='./data/img_align_celeba')
    parser.add_argument('--attr_path', dest='attr_path', type=str, default='./data/list_attr_celeba.txt')
    parser.add_argument('--img_size', dest='img_size', type=int, default=128)
    parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
    parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
    parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64)
    parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64)
    parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64)
    parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024)
    parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5)
    parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5)
    parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5)
    parser.add_argument('--enc_norm', dest='enc_norm', type=str, default='batchnorm')
    parser.add_argument('--dec_norm', dest='dec_norm', type=str, default='batchnorm')
    parser.add_argument('--dis_norm', dest='dis_norm', type=str, default='instancenorm')
    parser.add_argument('--dis_fc_norm', dest='dis_fc_norm', type=str, default='none')
    parser.add_argument('--enc_acti', dest='enc_acti', type=str, default='lrelu')
    parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu')
    parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu')
    parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu')
    parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
    parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
    parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
    parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)
    parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='# of epochs')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
    parser.add_argument('--num_workers', dest='num_workers', type=int, default=16)
    parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--beta1', dest='beta1', type=float, default=0.5)
    parser.add_argument('--beta2', dest='beta2', type=float, default=0.999)
    parser.add_argument('--n_d', dest='n_d', type=int, default=5, help='# of d updates per g update')
    parser.add_argument('--split_point', dest='split_point', type=int, default=182000, help='# of dataset split point')
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--save_interval', dest='save_interval', type=int, default=500)
    parser.add_argument('--experiment_name', dest='experiment_name',
                        default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
    parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
    parser.add_argument('--resume_model', action='store_true')
    parser.add_argument('--gen_ckpt_name', type=str, default='')
    parser.add_argument('--dis_ckpt_name', type=str, default='')
    return parser.parse_args(arg)
args = parse()
print(args)
args.lr_base = args.lr
args.n_attrs = len(args.attrs)

# initialize environment
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
if args.run_distribute:
    if os.getenv("DEVICE_ID", "not_set").isdigit():
        context.set_context(device_id=int(os.getenv("DEVICE_ID")))
    device_num = int(os.getenv('RANK_SIZE'))
    print(device_num)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                      device_num=device_num)
    init()
    rank = get_rank()
else:
    if os.getenv("DEVICE_ID", "not_set").isdigit():
        context.set_context(device_id=int(os.getenv("DEVICE_ID")))
    # Fall back to a single device when RANK_SIZE is not exported,
    # instead of failing on int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    rank = 0
print("Initialize successful!")

os.makedirs(join('output', args.experiment_name), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
if __name__ == '__main__':
    # Define dataloader
    train_dataset, train_length = data_loader(img_path=args.data_path,
                                              attr_path=args.attr_path,
                                              selected_attrs=args.attrs,
                                              mode="train",
                                              batch_size=args.batch_size,
                                              device_num=device_num,
                                              shuffle=True,
                                              split_point=args.split_point)
    train_loader = train_dataset.create_dict_iterator()
    print('Training images:', train_length)

    # Define network
    gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
              args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train')
    dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
              args.dis_layers, args.img_size, mode='train')

    # Initialize network
    init_weights(gen, 'KaimingUniform', math.sqrt(5))
    init_weights(dis, 'KaimingUniform', math.sqrt(5))

    # Resume from checkpoint
    if args.resume_model:
        para_gen = resume_generator(args, gen, args.gen_ckpt_name)
        para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
        load_param_into_net(gen, para_gen)
        load_param_into_net(dis, para_dis)

    # Define network with loss
    G_loss_cell = GenLoss(args, gen, dis)
    D_loss_cell = DisLoss(args, gen, dis)

    # Define Optimizer
    optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
    optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)

    # Define One Step Train
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizer_G)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizer_D)

    # Train
    G_trainOneStep.set_train(True)
    D_trainOneStep.set_train(True)
    print("Start Training")
    train_iter = train_length // args.batch_size
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_interval)
    if rank == 0:
        local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
        ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator')
        ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')
        cb_params_gen = _InternalCallbackParam()
        cb_params_gen.train_network = gen
        cb_params_gen.cur_epoch_num = 0
        gen_run_context = RunContext(cb_params_gen)
        ckpt_cb_gen.begin(gen_run_context)
        cb_params_dis = _InternalCallbackParam()
        cb_params_dis.train_network = dis
        cb_params_dis.cur_epoch_num = 0
        dis_run_context = RunContext(cb_params_dis)
        ckpt_cb_dis.begin(dis_run_context)

    # Initialize Progressbar
    progressbar = Progressbar()
    it = 0
    for epoch in range(args.epochs):
        for data in progressbar(train_loader, train_iter):
            img_a = data["image"]
            att_a = data["attr"]
            att_a = att_a.asnumpy()
            # Target attributes: the batch's own attribute rows in shuffled order
            att_b = np.random.permutation(att_a)
            # Map {0, 1} labels to {-thres_int, +thres_int}
            att_a_ = (att_a * 2 - 1) * args.thres_int
            att_b_ = (att_b * 2 - 1) * args.thres_int
            att_a = Tensor(att_a, mstype.float32)
            att_a_ = Tensor(att_a_, mstype.float32)
            att_b = Tensor(att_b, mstype.float32)
            att_b_ = Tensor(att_b_, mstype.float32)
            # n_d discriminator updates, then one generator update
            if (it + 1) % (args.n_d + 1) != 0:
                d_out, d_real_loss, d_fake_loss, dc_loss, df_gp = D_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
            else:
                g_out, gf_loss, gc_loss, gr_loss = G_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
                progressbar.say(epoch=epoch, iter=it + 1, d_loss=d_out, g_loss=g_out, gf_loss=gf_loss, gc_loss=gc_loss,
                                gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)
            if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
                cb_params_gen.cur_epoch_num = epoch + 1
                cb_params_dis.cur_epoch_num = epoch + 1
                cb_params_gen.cur_step_num = it + 1
                cb_params_dis.cur_step_num = it + 1
                cb_params_gen.batch_num = it + 2
                cb_params_dis.batch_num = it + 2
                ckpt_cb_gen.step_end(gen_run_context)
                ckpt_cb_dis.step_end(dis_run_context)
            it += 1