!22516 AttGAN ascend310 commit

Merge pull request !22516 from MR.D/dev
i-robot 2021-08-28 10:15:25 +00:00 committed by Gitee
commit fc746f737c
17 changed files with 865 additions and 136 deletions


@ -17,6 +17,8 @@
- [Evaluation](#评估)
- [Inference Process](#推理过程)
- [Export MindIR](#导出MindIR)
- [Inference on Ascend 310](#在Ascend310执行推理)
- [Results](#结果)
- [Model Description](#模型描述)
- [Performance](#性能)
- [Evaluation Performance](#评估性能)
@ -53,7 +55,7 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
- Hardware (Ascend)
- Set up the hardware environment with Ascend.
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
@ -70,17 +72,17 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
export RANK_SIZE=1
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
OR
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt
# Example: run distributed training
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
# Example: run evaluation
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
OR
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom enc_ckpt_name dec_ckpt_name
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name
```
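Distributed training requires an hccl configuration file in JSON format, created in advance. Pass the absolute path of this file as the first argument to the distributed training script.
@ -90,7 +92,7 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
The evaluation script requires a directory of custom images (jpg) and an attribute editing file, both created in advance. The attribute editing file is described in [Scripts and Sample Code](#脚本及样例代码). The directory and the attribute editing file correspond to the parameters `custom_data` and `custom_attr`. By default, the training script places checkpoint files in the
`/output/{experiment_name}/checkpoint` directory; when running the script, pass the names of the two checkpoint files, Encoder and Decoder, as arguments.
`/output/{experiment_name}/checkpoint` directory; when running the script, pass the name of the Generator checkpoint file as an argument.
[Note] All of the paths above should be absolute paths.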
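The absolute-path requirement can be checked before a script is launched. The following is a minimal sketch, assuming a Linux host; `require_absolute` is a hypothetical helper and is not part of this repository.

```python
import os


def require_absolute(**paths):
    """Raise ValueError naming every argument whose path is not absolute.

    Hypothetical helper for illustration; it is not part of the AttGAN scripts.
    """
    relative = [name for name, path in paths.items() if not os.path.isabs(path)]
    if relative:
        raise ValueError("paths must be absolute: " + ", ".join(sorted(relative)))


# The placeholder paths from the commands above pass the check.
require_absolute(
    custom_data="/path/data/custom/",
    custom_attr="/path/data/list_attr_custom.txt",
)
```

Calling it with a relative path such as `data/custom/` raises `ValueError` and names the offending argument.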
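@ -102,10 +104,12 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
.
└─ cv
└─ AttGAN
├── ascend310_infer # Ascend 310 inference directory
├── scripts
├──run_distribute_train.sh # shell script for distributed training
├──run_single_train.sh # shell script for single-device training
├──run_eval.sh # inference script
├──run_eval.sh # evaluation script
├──run_infer_310.sh # inference script
├─ src
├─ __init__.py # initialization file
├─ block.py # basic cells
@ -117,6 +121,9 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
├─ loss.py # loss computation
├─ eval.py # test script
├─ train.py # training script
├─ export.py # MINDIR model export script
├─ preprocess.py # preprocessing script for Ascend 310 inference
├─ postprocess.py # postprocessing script for Ascend 310 inference
└─ README_CN.md # description of AttGAN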
```
@ -141,7 +148,7 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
- Running on Ascend processors
```bash
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
```
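The shell script above runs distributed training in the background. It creates a LOG{RANK_ID} directory for each process under the script directory, and each process writes its output to log.txt in its own LOG{RANK_ID} directory. Checkpoint files are saved under output/experiment_name/rank{RANK_ID}.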
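@ -153,12 +160,12 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
- Evaluating a custom dataset in the Ascend environment
This network can be used to edit facial attributes. Place the images to edit in a custom image directory, and modify the attribute editing file according to the attributes you want to change (for the file's parameters, refer to the CelebA dataset and its attribute editing file). Then pass the custom image directory and the attribute editing file to the test script as the parameters custom_data and custom_attr.
For evaluation, choose previously generated checkpoint files and pass them to the test script through the parameters `enc_ckpt_name` and `dec_ckpt_name` (which hold the encoder and decoder parameters, respectively).
For evaluation, choose a previously generated checkpoint file and pass it to the test script through the parameter `gen_ckpt_name` (which holds both the encoder and decoder parameters).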
```bash
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
```
After the test script finishes, the edited images can be found under `output/{experiment_name}/custom_img` in the current directory.
@ -168,7 +175,7 @@ CelebFaces Attributes Dataset (CelebA) is a large-scale face attribute data
### Export MindIR
```shell
python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CKPT_NAME] --dec_ckpt_name [DECODER_CKPT_NAME] --file_format [FILE_FORMAT]
python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT]
```
`file_format` must be one of ["AIR", "MINDIR"].
@ -176,6 +183,26 @@ python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CK
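The script generates the corresponding MINDIR file in the current directory.
### Inference on Ascend 310
Before running inference, the MINDIR model must be exported with the export script. The following command shows how to edit image attributes on Ascend 310: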
```bash
bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```
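- `GEN_MINDIR_PATH`: path to the MINDIR file.
- `ATTR_FILE_PATH`: path to the attribute editing file; it should be an absolute path.
- `DATA_PATH`: directory of the dataset to run inference on; images should be in jpg format.
- `NEED_PREPROCESS`: whether the attribute editing file needs preprocessing, either `y` or `n`. Choose `y` to preprocess. The attribute editing file must be preprocessed the first time inference is run, which can take some time when there are many images.
- `DEVICE_ID`: optional; defaults to 0.
[Note] The format of the attribute editing file follows the list_attr_celeba.txt file in the CelebA dataset: the first line is the number of images to run inference on, the second line lists the attributes to edit, and the remaining lines give the name of each image to edit and its attribute tags. The number of images in the attribute editing file must equal the number of images in the dataset directory.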
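The layout described in the note can be illustrated with a small parser. This is a minimal sketch under stated assumptions: `parse_attr_file` is a hypothetical helper, not the repository's `preprocess.py`, and the tags are assumed to be `1`/`-1` as in CelebA's list_attr_celeba.txt. The sample uses only three attributes to stay short.

```python
import io


def parse_attr_file(stream):
    """Parse an attribute editing file laid out like list_attr_celeba.txt.

    Line 1: number of images. Line 2: attribute names.
    Remaining lines: image name followed by one tag per attribute.
    """
    n_images = int(stream.readline().strip())
    attr_names = stream.readline().split()
    records = {}
    for line in stream:
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        name, tags = fields[0], [int(t) for t in fields[1:]]
        if len(tags) != len(attr_names):
            raise ValueError(f"{name}: expected {len(attr_names)} tags, got {len(tags)}")
        records[name] = dict(zip(attr_names, tags))
    if len(records) != n_images:
        raise ValueError(f"header declares {n_images} images, found {len(records)}")
    return attr_names, records


# Illustrative sample with made-up tags for two images.
sample = io.StringIO(
    "2\n"
    "Bald Bangs Eyeglasses\n"
    "182001.jpg -1 1 -1\n"
    "182002.jpg 1 -1 1\n"
)
attr_names, records = parse_attr_file(sample)
print(attr_names)
print(records["182001.jpg"])
```

The count check mirrors the requirement that the image count in the header match the entries that follow; comparing it with the number of files in the dataset directory would be a separate step.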
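### Results
Inference results are saved in the directory where the script is run. Edited images are saved under `result_Files/`, and inference timing statistics under `time_Result/`. Edited images are named `imgName_attrId.jpg`; for example, `182001_1.jpg` is the result of editing the first attribute of the image named 182001. Whether that attribute is actually edited depends on the contents of the attribute editing file.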
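The `imgName_attrId.jpg` naming can be split back into its two parts when collecting results. A minimal sketch follows; `split_result_name` is a hypothetical helper and is not part of the repository.

```python
from pathlib import Path


def split_result_name(file_name):
    """Split 'imgName_attrId.jpg' into (imgName, attrId).

    The attribute id is taken after the last underscore, so image
    names that contain underscores themselves are still handled.
    """
    stem = Path(file_name).stem
    img_name, sep, attr_id = stem.rpartition("_")
    if not sep or not attr_id.isdigit():
        raise ValueError(f"not an imgName_attrId file name: {file_name}")
    return img_name, int(attr_id)


print(split_result_name("result_Files/182001_1.jpg"))  # ('182001', 1)
```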
# Model Description
## Performance


@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)


@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make


@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <string_view>
#include <memory>
#include "include/api/types.h"
DIR *OpenDir(std::string_view dirName);
void Denorm(std::vector<mindspore::MSTensor> *outputs);
std::string RealPath(std::string_view path);
std::vector<std::string> GetAllFiles(std::string_view dirName);
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif


@ -0,0 +1,149 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <map>
#include <memory>
#include <type_traits>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using color_rep_type = std::underlying_type<mindspore::DataType>::type;
DEFINE_string(gen_mindir_path, "", "generator mindir path");
DEFINE_string(dataset_path, "", "dataset path");
DEFINE_string(attr_file_path, "", "attribute file path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_int32(image_height, 128, "image height");
DEFINE_int32(image_width, 128, "image width");
int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_gen_mindir_path).empty()) {
    std::cout << "Invalid generator mindir" << std::endl;
    return 1;
  }
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  ascend310->SetBufferOptimizeMode("off_optimize");
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph gen_graph;
  // Stop early if the MindIR file cannot be loaded instead of failing later in Build().
  Status load_ret = Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph);
  if (load_ret != kSuccess) {
    std::cout << "ERROR: Load generator mindir failed." << std::endl;
    return 1;
  }
  Model gen_model;
  Status gen_ret = gen_model.Build(GraphCell(gen_graph), context);
  if (gen_ret != kSuccess) {
    std::cout << "ERROR: Generator build failed." << std::endl;
    return 1;
  }
  size_t n_attrs = 13;
  auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs);
  auto all_files = GetAllFiles(FLAGS_dataset_path);
  std::map<double, double> costTime_map;
  double startTimeMs;
  double endTimeMs;
  size_t size = all_files.size();
  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    std::cout << "Start predict input files:" << all_files[i] << std::endl;
    auto img = std::make_shared<MSTensor>();
    std::shared_ptr<TensorTransform> decode(new Decode());
    std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
    auto resizeShape = {FLAGS_image_height, FLAGS_image_width};
    std::shared_ptr<TensorTransform> resize(new Resize(resizeShape));
    std::shared_ptr<TensorTransform> normalize(
      new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
    Execute composeDecode({decode, resize, normalize, hwc2chw});
    auto image = ReadFileToTensor(all_files[i]);
    composeDecode(image, img.get());
    for (size_t k = 0; k < n_attrs; ++k) {
      gettimeofday(&start, nullptr);
      std::vector<MSTensor> inputs;
      std::vector<MSTensor> outputs;
      size_t index = i * n_attrs + k;
      // Guard against an attribute file with fewer entries than the dataset needs.
      if (index >= all_cfg.size()) {
        std::cout << "Attribute file has no entry " << index << " for " << all_files[i] << std::endl;
        return 1;
      }
      std::cout << static_cast<color_rep_type>(img->DataType()) << std::endl;
      inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize());
      inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(),
                          all_cfg[index].Data().get(), all_cfg[index].DataSize());
      Status gen_model_ret = gen_model.Predict(inputs, &outputs);
      if (gen_model_ret != kSuccess) {
        std::cout << "Generator inference " << all_files[i] << " failed." << std::endl;
        return 1;
      }
      Denorm(&outputs);
      // Strip the extension at the last dot so dots in directory names are left alone.
      size_t pos = all_files[i].rfind('.');
      std::string fileName = all_files[i].substr(0, pos);
      WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs);
      gettimeofday(&end, nullptr);
      startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
      endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
      costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    }
  }
  double average = 0.0;
  int inferCount = 0;
  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  // Avoid dividing by zero when no inference was run.
  average = (inferCount > 0) ? average / inferCount : 0.0;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}


@ -0,0 +1,201 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  closedir(dir);  // release the directory handle opened by OpenDir
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}
int WriteResult(const std::string &imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    std::shared_ptr<const void> netOutput = outputs[i].Data();
    size_t outputSize = outputs[i].DataSize();
    // Keep only the file name, then swap its extension for ".bin".
    size_t pos = imageFile.rfind('/');
    std::string fileName = (pos == std::string::npos) ? imageFile : imageFile.substr(pos + 1);
    std::cout << fileName << std::endl;
    size_t dot = fileName.rfind('.');
    if (dot != std::string::npos) {
      fileName.erase(dot);
    }
    fileName += ".bin";
    std::string outFileName = homePath + "/" + fileName;
    FILE *outputFile = fopen(outFileName.c_str(), "wb");
    if (outputFile == nullptr) {
      std::cout << "Open file " << outFileName << " failed." << std::endl;
      return 1;
    }
    fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "File path is empty" << std::endl;
    return mindspore::MSTensor();
  }
  // Images are binary data, so open the stream in binary mode.
  std::ifstream ifs(file, std::ios::binary);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }
  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }
  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
                             {static_cast<int64_t>(size)}, nullptr, size);
  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();
  return buffer;
}
std::vector<std::string> split(const std::string &inputs) {
  std::vector<std::string> line;
  std::stringstream stream(inputs);
  std::string result;
  while (stream >> result) {
    line.push_back(result);
  }
  return line;
}
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr) {
  std::vector<mindspore::MSTensor> res;
  if (file.empty()) {
    std::cout << "File path is empty." << std::endl;
    exit(1);
  }
  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist." << std::endl;
    exit(1);
  }
  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed." << std::endl;
    exit(1);
  }
  std::string n_images;
  std::string n_attrs;
  getline(ifs, n_images);
  getline(ifs, n_attrs);
  auto n_images_ = std::stoi(n_images);
  auto n_attrs_ = std::stoi(n_attrs);
  *n_ptr = n_attrs_;
  std::cout << "Image number is " << n_images << std::endl;
  std::cout << "Attribute number is " << n_attrs << std::endl;
  auto all_lines = n_images_ * n_attrs_;
  for (auto i = 0; i < all_lines; i++) {
    std::string val;
    getline(ifs, val);
    std::vector<std::string> val_split = split(val);
    if (val_split.size() < static_cast<size_t>(n_attrs_)) {
      std::cout << "Line " << i << " of " << file << " has fewer than " << n_attrs_ << " values." << std::endl;
      exit(1);
    }
    // A std::vector owns the staging buffer, so nothing leaks after it is copied into the tensor.
    std::vector<float> elements(n_attrs_);
    for (auto j = 0; j < n_attrs_; j++) {
      elements[j] = atof(val_split[j].c_str());
    }
    auto size = sizeof(float) * n_attrs_;
    mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32,
                               {static_cast<int64_t>(size)}, nullptr, size);
    memcpy(buffer.MutableData(), elements.data(), size);
    res.emplace_back(buffer);
  }
  ifs.close();
  return res;
}
DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << "dirName is empty!" << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  // lstat leaves `s` unset on failure, so check its return value before reading st_mode.
  if (lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory!" << std::endl;
    return nullptr;
  }
  DIR *dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}
std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  // string_view::data() is not guaranteed to be null-terminated, so copy into a std::string first.
  std::string pathStr(path);
  char *realPathRet = realpath(pathStr.c_str(), realPathMem);
  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }
  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}
void Denorm(std::vector<MSTensor> *outputs) {
  for (size_t i = 0; i < outputs->size(); ++i) {
    size_t outputSize = (*outputs)[i].DataSize();
    float *netOutput = reinterpret_cast<float *>((*outputs)[i].MutableData());
    size_t outputLen = outputSize / sizeof(float);
    for (size_t j = 0; j < outputLen; ++j) {
      netOutput[j] = (netOutput[j] + 1) / 2 * 255;
      netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j];
      netOutput[j] = (netOutput[j] > 255) ? 255 : netOutput[j];
    }
  }
}


@ -28,12 +28,12 @@ import mindspore.dataset as de
from mindspore import context, Tensor, ops
from mindspore.train.serialization import load_param_into_net
from src.attgan import Genc, Gdec
from src.attgan import Gen
from src.cell import init_weights
from src.data import check_attribute_conflict
from src.data import get_loader, Custom
from src.helpers import Progressbar
from src.utils import resume_model, denorm
from src.utils import resume_generator, denorm
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
@ -45,8 +45,7 @@ def parse(arg=None):
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
parser.add_argument('--num_test', dest='num_test', type=int)
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--custom_img', action='store_true')
parser.add_argument('--custom_data', type=str, default='../data/custom')
parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
@ -62,8 +61,7 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.enc_ckpt_name = args_.enc_ckpt_name
args.dec_ckpt_name = args_.dec_ckpt_name
args.gen_ckpt_name = args_.gen_ckpt_name
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr
@ -101,15 +99,15 @@ else:
print('Testing images:', min(test_len, args.num_test))
# Model loader
genc = Genc(mode='test')
gdec = Gdec(shortcut_layers=args.shortcut_layers, inject_layers=args.inject_layers, mode='test')
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test')
# Initialize network
init_weights(genc, 'KaimingUniform', math.sqrt(5))
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
init_weights(gen, 'KaimingUniform', math.sqrt(5))
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
load_param_into_net(gen, para_gen)
print("Network initializes successfully.")
progressbar = Progressbar()
it = 0
@ -134,8 +132,7 @@ for data in test_dataset_iter:
att_b_ = (att_b * 2 - 1) * args.thres_int
if i > 0:
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
a_enc = genc(img_a)
samples.append(gdec(a_enc, att_b_))
samples.append(gen(img_a, att_b_, mode="enc-dec"))
cat = ops.Concat(axis=3)
samples = cat(samples).asnumpy()
result = denorm(samples)


@ -21,14 +21,13 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net
from src.utils import resume_model
from src.attgan import Genc, Gdec
from src.utils import resume_generator
from src.attgan import Gen
parser = argparse.ArgumentParser(description='Attribute Edit')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
@ -39,27 +38,20 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.device_id = args_.device_id
args.batch_size = args_.batch_size
args.enc_ckpt_name = args_.enc_ckpt_name
args.dec_ckpt_name = args_.dec_ckpt_name
args.gen_ckpt_name = args_.gen_ckpt_name
args.file_format = args_.file_format
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == '__main__':
genc = Genc(mode='test')
gdec = Gdec(mode='test')
gen = Gen(mode="test")
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
load_param_into_net(gen, para_gen)
enc_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
dec_array = genc(enc_array)
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))
G_enc_file = f"AttGAN_Generator_Encoder"
export(genc, enc_array, file_name=G_enc_file, file_format=args.file_format)
G_dec_file = f"AttGAN_Generator_Decoder"
export(gdec, *(dec_array, input_label), file_name=G_dec_file, file_format=args.file_format)
Gen_file = f"attgan_mindir"
export(gen, *(input_array, input_label), file_name=Gen_file, file_format=args.file_format)


@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from PIL import Image
def parse(arg=None):
    """Define configuration of postprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--bin_path', type=str, default='./result_Files/')
    parser.add_argument('--target_path', type=str, default='./result_Files/')
    return parser.parse_args(arg)


def load_bin_file(bin_file, shape=None, dtype="float32"):
    """Load data from bin file"""
    data = np.fromfile(bin_file, dtype=dtype)
    if shape:
        data = np.reshape(data, shape)
    return data


def save_bin_to_image(data, out_name):
    """Save a CHW array as an image file"""
    image = np.transpose(data, (1, 2, 0))
    im = Image.fromarray(np.uint8(image))
    im.save(out_name)
    print("Successfully save image in " + out_name)


def scan_dir(bin_path):
    """Scan directory"""
    out = os.listdir(bin_path)
    return out


def postprocess(bin_path, target_path):
    """Post process bin file"""
    file_list = scan_dir(bin_path)
    for file in file_list:
        # Skip anything that is not a .bin output, e.g. images saved by an earlier run.
        if not file.endswith(".bin"):
            continue
        data = load_bin_file(os.path.join(bin_path, file), shape=(3, 128, 128), dtype="float32")
        file_name = os.path.splitext(file)[0] + ".jpg"
        outfile = os.path.join(target_path, file_name)
        save_bin_to_image(data, outfile)


if __name__ == "__main__":
    args = parse()
    postprocess(args.bin_path, args.target_path)


@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pre process for 310 inference"""
import os
from os.path import join
import argparse
import numpy as np
selected_attrs = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]


def parse(arg=None):
    """Define configuration of preprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
    parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    return parser.parse_args(arg)


args = parse()
args.n_attrs = len(args.attrs)
def check_attribute_conflict(att_batch, att_name, att_names):
    """Check Attributes"""
    def _set(att, att_name):
        if att_name in att_names:
            att[att_names.index(att_name)] = 0.0

    att_id = att_names.index(att_name)
    for att in att_batch:
        if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
            _set(att, 'Bangs')
        elif att_name == 'Bangs' and att[att_id] != 0:
            _set(att, 'Bald')
            _set(att, 'Receding_Hairline')
        elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
            for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
            for n in ['Straight_Hair', 'Wavy_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
            for n in ['Mustache', 'No_Beard']:
                if n != att_name:
                    _set(att, n)
    return att_batch
def read_cfg_file(attr_path):
    """Read configuration from attribute file"""
    # Read the two header lines once and close the file afterwards.
    with open(attr_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    attr_number = int(lines[0])
    attr_list = lines[1].split()
    atts = [attr_list.index(att) + 1 for att in selected_attrs]
    # np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
    labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
    labels = [labels] if attr_number == 1 else labels[0:]
    new_attr = []
    for index in range(attr_number):
        att = [np.asarray((labels[index] + 1) // 2)]
        new_attr.append(att)
    new_attr = np.array(new_attr)
    return new_attr, attr_number
def preprocess_cfg(attrs, numbers):
    """Preprocess attribute file"""
    new_attr = []
    for index in range(numbers):
        attr = attrs[index]
        att_b_list = [attr]
        for i in range(args.n_attrs):
            tmp = attr.copy()
            tmp[:, i] = 1 - tmp[:, i]
            tmp = check_attribute_conflict(tmp, selected_attrs[i], selected_attrs)
            att_b_list.append(tmp)
        for i, att_b in enumerate(att_b_list):
            att_b_ = (att_b * 2 - 1) * args.thres_int
            if i > 0:
                att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
            new_attr.append(att_b_)
    return new_attr
def write_cfg_file(attrs, numbers):
    """Write attribute file"""
    cur_dir = os.getcwd()
    print(cur_dir)
    path = join(cur_dir, 'attrs.txt')
    with open(path, "w") as f:
        f.write(str(numbers) + "\n")
        f.write(str(args.n_attrs) + "\n")
        counts = numbers * args.n_attrs
        for index in range(counts):
            attrs_list = attrs[index][0]
            new_attrs_list = ["%s" % x for x in attrs_list]
            sequence = " ".join(new_attrs_list)
            f.write(sequence + "\n")
    print("Generate cfg file successfully.")


if __name__ == "__main__":
    # Stop here when the attribute file is missing instead of failing later in read_cfg_file.
    if not args.attrs_path or not os.path.isfile(args.attrs_path):
        raise SystemExit("Path is not correct!")
    attributes, n_images = read_cfg_file(args.attrs_path)
    new_attrs = preprocess_cfg(attributes, n_images)
    write_cfg_file(new_attrs, n_images)


@ -14,17 +14,16 @@
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [ENC_CKPT_NAME] [DEC_CKPT_NAME]"
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]"
exit 1
fi
experiment_name=$1
data_path=$2
attr_path=$3
enc_name=$4
dec_name=$5
gen_ckpt_name=$4
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "The number of logical cores:" $cores
@ -47,5 +46,4 @@ python eval.py \
--custom_data $data_path \
--custom_attr $attr_path \
--custom_img \
--enc_ckpt_name $enc_name \
--dec_ckpt_name $dec_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
--gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 &


@ -0,0 +1,128 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
NEED_PREPROCESS means whether preprocessing is needed; its value is 'y' or 'n'.
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
gen_model=$(get_real_path $1)
attr_path=$(get_real_path $2)
data_path=$(get_real_path $3)
if [ "$4" == "y" ] || [ "$4" == "n" ];then
need_preprocess=$4
else
echo "whether to preprocess or not; its value must be in [y, n]"
exit 1
fi
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
echo "generator mindir name: "$gen_model
echo "attribute file path: "$attr_path
echo "dataset path: "$data_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
echo "Start to preprocess attr file..."
python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log
echo "Attribute file generated successfully!"
}
function compile_app()
{
echo "Start to compile source code..."
cd ../ascend310_infer || exit
bash build.sh &> build.log
echo "Compile successfully."
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
echo "Start to execute inference..."
../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log
}
function postprocess_data()
{
echo "Start to postprocess image file..."
python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
}
if [ $need_preprocess == "y" ]; then
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess attrs failed"
exit 1
fi
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
postprocess_data
if [ $? -ne 0 ]; then
echo "postprocess images failed"
exit 1
fi
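The argument handling at the top of `run_infer_310.sh` (four or five positional arguments, a `y`/`n` flag, an optional device id defaulting to 0, and relative paths resolved against the working directory) might be mirrored in Python roughly as below. This is a sketch under assumptions: `parse_infer_args` and the returned dictionary keys are hypothetical names, `os.path.realpath` stands in for `realpath -m`, and the device id is converted to `int` although the shell script keeps it as a string.

```python
import os


def get_real_path(path):
    """Return absolute paths unchanged; resolve relative ones against the cwd."""
    if path.startswith("/"):
        return path
    # os.path.realpath does not require the path to exist (non-strict mode).
    return os.path.realpath(os.path.join(os.getcwd(), path))


def parse_infer_args(argv):
    """Validate the positional arguments the way the shell script does."""
    if not 4 <= len(argv) <= 5:
        raise ValueError("expected 4 or 5 arguments, got {}".format(len(argv)))
    gen_model, attr_path, data_path, need_preprocess = argv[:4]
    if need_preprocess not in ("y", "n"):
        raise ValueError("NEED_PREPROCESS must be 'y' or 'n'")
    # Assumption: the device id is an integer; the script defaults it to 0.
    device_id = int(argv[4]) if len(argv) == 5 else 0
    return {
        "gen_model": get_real_path(gen_model),
        "attr_path": get_real_path(attr_path),
        "data_path": get_real_path(data_path),
        "need_preprocess": need_preprocess,
        "device_id": device_id,
    }


parsed = parse_infer_args(["/models/gen.mindir", "attrs_custom.txt", "/data/custom", "y"])
print(parsed["device_id"], parsed["need_preprocess"])
```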


@ -22,12 +22,15 @@ from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
# Image size 128 x 128
MAX_DIM = 64 * 16
class Genc(nn.Cell):
"""Generator encoder"""
class Gen(nn.Cell):
"""Generator"""
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
mode='test'):
dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu",
n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"):
super().__init__()
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
self.inject_layers = min(inject_layers, dec_layers - 1)
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
layers = []
n_in = 3
@ -39,27 +42,7 @@ class Genc(nn.Cell):
n_in = n_out
self.enc_layers = nn.CellList(layers)
def construct(self, x):
"""Encoder construct"""
z = x
zs = []
for layer in self.enc_layers:
z = layer(z)
zs.append(z)
return zs
class Gdec(nn.Cell):
"""Generator decoder"""
def __init__(self, dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", n_attrs=13,
shortcut_layers=1, inject_layers=1, img_size=128, mode='test'):
super().__init__()
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
self.inject_layers = min(inject_layers, dec_layers - 1)
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
layers = []
n_in = 1024
n_in = n_in + n_attrs # 1024 + 13
for i in range(dec_layers):
if i < dec_layers - 1:
@ -80,7 +63,16 @@ class Gdec(nn.Cell):
self.repeat = P.Tile()
self.cat = P.Concat(1)
def construct(self, zs, a):
def encoder(self, x):
"""Encoder construct"""
z = x
zs = []
for layer in self.enc_layers:
z = layer(z)
zs.append(z)
return zs
def decoder(self, zs, a):
"""Decoder construct"""
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
multiples = (1, 1, self.f_size, self.f_size)
@ -100,6 +92,16 @@ class Gdec(nn.Cell):
i = i + 1
return z
def construct(self, x, a=None, mode="enc-dec"):
result = None
if mode == "enc-dec":
out = self.encoder(x)
result = self.decoder(out, a)
if mode == "enc":
result = self.encoder(x)
if mode == "dec":
result = self.decoder(x, a)
return result
class Dis(nn.Cell):
"""Discriminator"""
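The mode dispatch that the merged `Gen.construct` introduces can be illustrated without MindSpore. The toy class below is only a sketch: `ToyGen`, its arithmetic "layers", and the `ValueError` for unknown modes are assumptions made for illustration (the committed `construct` returns `None` for an unrecognised mode).

```python
class ToyGen:
    """Toy stand-in for a single cell exposing encoder, decoder, and both."""

    def encoder(self, x):
        # Stand-in for the encoder stack: returns every intermediate feature.
        return [x + 1, x + 2]

    def decoder(self, zs, a):
        # Stand-in for the decoder: combines the deepest feature with attributes.
        return zs[-1] * a

    def __call__(self, x, a=None, mode="enc-dec"):
        if mode == "enc-dec":
            return self.decoder(self.encoder(x), a)
        if mode == "enc":
            return self.encoder(x)
        if mode == "dec":
            # In "dec" mode, x is the list of encoder features.
            return self.decoder(x, a)
        raise ValueError("unknown mode: {}".format(mode))


gen = ToyGen()
features = gen(1, mode="enc")           # encode once
fake = gen(features, 3, mode="dec")     # decode with target attributes
recon = gen(features, 1, mode="dec")    # decode again, reusing the features
print(features, fake, recon)
```

Encoding once and decoding twice matches how `GenLoss` calls the generator, whereas `DisLoss` uses the one-shot `"enc-dec"` mode.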


@ -116,8 +116,7 @@ class TrainOneStepCellGen(nn.Cell):
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss, gf_loss, gc_loss, gr_loss
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
class TrainOneStepCellDis(nn.Cell):
@ -153,5 +152,4 @@ class TrainOneStepCellDis(nn.Cell):
if self.reducer_flag:
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss, d_real_loss, d_fake_loss, dc_loss, df_gp
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp


@ -90,10 +90,9 @@ class WGANGPGradientPenalty(nn.Cell):
class GenLoss(nn.Cell):
"""Define total Generator loss"""
def __init__(self, args, encoder, decoder, discriminator):
def __init__(self, args, generator, discriminator):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.generator = generator
self.discriminator = discriminator
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
@ -108,9 +107,9 @@ class GenLoss(nn.Cell):
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
"""Get generator loss"""
# generate
zs_a = self.encoder(img_a)
img_fake = self.decoder(zs_a, att_b_)
img_recon = self.decoder(zs_a, att_a_)
zs_a = self.generator(img_a, mode="enc")
img_fake = self.generator(zs_a, att_b_, mode="dec")
img_recon = self.generator(zs_a, att_a_, mode="dec")
# discriminate
d_fake, dc_fake = self.discriminator(img_fake)
@ -128,10 +127,9 @@ class GenLoss(nn.Cell):
class DisLoss(nn.Cell):
"""Define total discriminator loss"""
def __init__(self, args, encoder, decoder, discriminator):
def __init__(self, args, generator, discriminator):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.generator = generator
self.discriminator = discriminator
self.cyc_loss = P.ReduceMean()
@ -146,8 +144,7 @@ class DisLoss(nn.Cell):
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
"""Get discriminator loss"""
# generate
z = self.encoder(img_a)
img_fake = self.decoder(z, att_b_)
img_fake = self.generator(img_a, att_b_, mode="enc-dec")
# discriminate
d_real, dc_real = self.discriminator(img_a)


@ -60,16 +60,13 @@ class DistributedSampler:
def __len__(self):
return self.num_samples
def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name):
"""Restore the trained generator and discriminator"""
def resume_generator(args, generator, gen_ckpt_name):
"""Restore the trained generator"""
print("Loading the trained models from step {}...".format(args.save_interval))
encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name)
decoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name)
generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name)
param_generator = load_checkpoint(generator_path, generator)
param_encoder = load_checkpoint(encoder_path, encoder)
param_decoder = load_checkpoint(decoder_path, decoder)
return param_encoder, param_decoder
return param_generator
def resume_discriminator(args, discriminator, dis_ckpt_name):
"""Restore the trained discriminator"""
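For orientation, the path that `resume_generator` builds can be reproduced with a small helper. This is a sketch with hypothetical names (`generator_ckpt_path`, `ckpt_file_name`); the `rank` parameter is an assumption, since the committed function hard-codes `rank0`, and the `prefix-epoch_step.ckpt` naming pattern is inferred from the README example `generator-119_84999.ckpt`.

```python
import os


def ckpt_file_name(prefix, epoch, step):
    """Build a checkpoint file name in the prefix-epoch_step.ckpt pattern (assumed)."""
    return "{}-{}_{}.ckpt".format(prefix, epoch, step)


def generator_ckpt_path(experiment_name, gen_ckpt_name, rank=0):
    """Join the directory layout used by resume_generator."""
    return os.path.join("output", experiment_name,
                        "checkpoint/rank{}".format(rank), gen_ckpt_name)


name = ckpt_file_name("generator", 119, 84999)
path = generator_ckpt_path("128_shortcut1_inject1_none", name)
print(path)
```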


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================sss
# ============================================================================
"""Entry point for training AttGAN network"""
import argparse
@ -32,12 +32,12 @@ from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from mindspore.train.serialization import load_param_into_net
from src.attgan import Genc, Gdec, Dis
from src.attgan import Gen, Dis
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
from src.data import data_loader
from src.helpers import Progressbar
from src.loss import GenLoss, DisLoss
from src.utils import resume_model, resume_discriminator
from src.utils import resume_generator, resume_discriminator
attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
@ -94,8 +94,7 @@ def parse(arg=None):
default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
parser.add_argument('--resume_model', action='store_true')
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--dis_ckpt_name', type=str, default='')
return parser.parse_args(arg)
@ -150,32 +149,28 @@ if __name__ == '__main__':
print('Training images:', train_length)
# Define network
genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train')
gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers,
args.inject_layers, args.img_size, mode='train')
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train')
dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
args.dis_layers, args.img_size, mode='train')
# Initialize network
init_weights(genc, 'KaimingUniform', math.sqrt(5))
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
init_weights(gen, 'KaimingUniform', math.sqrt(5))
init_weights(dis, 'KaimingUniform', math.sqrt(5))
# Resume from checkpoint
if args.resume_model:
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
load_param_into_net(gen, para_gen)
load_param_into_net(dis, para_dis)
# Define network with loss
G_loss_cell = GenLoss(args, genc, gdec, dis)
D_loss_cell = DisLoss(args, genc, gdec, dis)
G_loss_cell = GenLoss(args, gen, dis)
D_loss_cell = DisLoss(args, gen, dis)
# Define Optimizer
G_trainable_params = genc.trainable_params() + gdec.trainable_params()
optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
# Define One Step Train
@ -193,21 +188,14 @@ if __name__ == '__main__':
if rank == 0:
local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder')
ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder')
ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator')
ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')
cb_params_genc = _InternalCallbackParam()
cb_params_genc.train_network = genc
cb_params_genc.cur_epoch_num = 0
genc_run_context = RunContext(cb_params_genc)
ckpt_cb_genc.begin(genc_run_context)
cb_params_gdec = _InternalCallbackParam()
cb_params_gdec.train_network = gdec
cb_params_gdec.cur_epoch_num = 0
gdec_run_context = RunContext(cb_params_gdec)
ckpt_cb_gdec.begin(gdec_run_context)
cb_params_gen = _InternalCallbackParam()
cb_params_gen.train_network = gen
cb_params_gen.cur_epoch_num = 0
gen_run_context = RunContext(cb_params_gen)
ckpt_cb_gen.begin(gen_run_context)
cb_params_dis = _InternalCallbackParam()
cb_params_dis.train_network = dis
@ -243,16 +231,12 @@ if __name__ == '__main__':
gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)
if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
cb_params_genc.cur_epoch_num = epoch + 1
cb_params_gdec.cur_epoch_num = epoch + 1
cb_params_gen.cur_epoch_num = epoch + 1
cb_params_dis.cur_epoch_num = epoch + 1
cb_params_genc.cur_step_num = it + 1
cb_params_gdec.cur_step_num = it + 1
cb_params_gen.cur_step_num = it + 1
cb_params_dis.cur_step_num = it + 1
cb_params_genc.batch_num = it + 2
cb_params_gdec.batch_num = it + 2
cb_params_gen.batch_num = it + 2
cb_params_dis.batch_num = it + 2
ckpt_cb_genc.step_end(genc_run_context)
ckpt_cb_gdec.step_end(gdec_run_context)
ckpt_cb_gen.step_end(gen_run_context)
ckpt_cb_dis.step_end(dis_run_context)
it += 1