commit
fc746f737c
|
@ -17,6 +17,8 @@
|
|||
- [评估](#评估)
|
||||
- [推理过程](#推理过程)
|
||||
- [导出MindIR](#导出MindIR)
|
||||
- [在Ascend310执行推理](#在Ascend310执行推理)
|
||||
- [结果](#结果)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
|
@ -53,7 +55,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
- 硬件(Ascend)
|
||||
- 使用Ascend来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
@ -70,17 +72,17 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
export RANK_SIZE=1
|
||||
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
|
||||
OR
|
||||
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt
|
||||
|
||||
# 运行分布式训练示例
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
|
||||
|
||||
# 运行评估示例
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
|
||||
OR
|
||||
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom enc_ckpt_name dec_ckpt_name
|
||||
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name
|
||||
```
|
||||
|
||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。该配置文件的绝对路径作为运行分布式脚本的第一个参数。
|
||||
|
@ -90,7 +92,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||
|
||||
对于评估脚本,需要提前创建存放自定义图片(jpg)的目录以及属性编辑文件,关于属性编辑文件的说明见[脚本及样例代码](#脚本及样例代码)。目录以及属性编辑文件分别对应参数`custom_data`和`custom_attr`。checkpoint文件被训练脚本默认放置在
|
||||
`/output/{experiment_name}/checkpoint`目录下,执行脚本时需要将两个检查点文件(Encoder和Decoder)的名称作为参数传入。
|
||||
`/output/{experiment_name}/checkpoint`目录下,执行脚本时需要将检查点文件(Generator)的名称作为参数传入。
|
||||
|
||||
[注意] 以上路径均应设置为绝对路径
|
||||
|
||||
|
@ -102,10 +104,12 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
.
|
||||
└─ cv
|
||||
└─ AttGAN
|
||||
├── ascend310_infer # 310推理目录
|
||||
├── scripts
|
||||
├──run_distribute_train.sh # 分布式训练的shell脚本
|
||||
├──run_single_train.sh # 单卡训练的shell脚本
|
||||
├──run_eval.sh # 推理脚本
|
||||
├──run_eval.sh # 评估脚本
|
||||
├──run_infer_310.sh # 推理脚本
|
||||
├─ src
|
||||
├─ __init__.py # 初始化文件
|
||||
├─ block.py # 基础cell
|
||||
|
@ -117,6 +121,9 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
├─ loss.py # loss计算
|
||||
├─ eval.py # 测试脚本
|
||||
├─ train.py # 训练脚本
|
||||
├─ export.py # MINDIR模型导出脚本
|
||||
├─ preprocess.py # 310推理预处理脚本
|
||||
├─ postprocess.py # 310推理后处理脚本
|
||||
└─ README_CN.md # AttGAN的文件描述
|
||||
```
|
||||
|
||||
|
@ -141,7 +148,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布式训练。该脚本将在脚本目录下生成相应的LOG{RANK_ID}目录,每个进程的输出记录在相应LOG{RANK_ID}目录下的log.txt文件中。checkpoint文件保存在output/experiment_name/rank{RANK_ID}下。
|
||||
|
@ -153,12 +160,12 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
- 在Ascend环境运行时评估自定义数据集
|
||||
该网络可以用于修改面部属性,用户将希望修改的图片放在自定义的图片目录下,并根据自己期望修改的属性来修改属性编辑文件(文件的具体参数参照CelebA数据集及属性编辑文件)。完成后,需要将自定义图片目录和属性编辑文件作为参数传入测试脚本,分别对应custom_data以及custom_attr。
|
||||
|
||||
评估时选择已经生成好的检查点文件,作为参数传入测试脚本,对应参数为`enc_ckpt_name`和`dec_ckpt_name`(分别保存了编码器和解码器的参数)
|
||||
评估时选择已经生成好的检查点文件,作为参数传入测试脚本,对应参数为`gen_ckpt_name`(保存了编码器和解码器的参数)
|
||||
|
||||
```bash
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
|
||||
```
|
||||
|
||||
测试脚本执行完成后,用户进入当前目录下的`output/{experiment_name}/custom_img`下查看修改好的图片。
|
||||
|
@ -168,7 +175,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据
|
|||
### 导出MindIR
|
||||
|
||||
```shell
|
||||
python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CKPT_NAME] --dec_ckpt_name [DECODER_CKPT_NAME] --file_format [FILE_FORMAT]
|
||||
python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
`file_format` 必须在 ["AIR", "MINDIR"]中选择。
|
||||
|
@ -176,6 +183,26 @@ python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CK
|
|||
|
||||
脚本会在当前目录下生成对应的MINDIR文件。
|
||||
|
||||
### 在Ascend310执行推理
|
||||
|
||||
在执行推理前,必须通过export脚本导出MINDIR模型。以下命令展示了如何通过命令在Ascend310上对图片进行属性编辑:
|
||||
|
||||
```bash
|
||||
bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- `MINDIR_PATH` MINDIR文件的路径
|
||||
- `ATTR_FILE_PATH` 属性编辑文件的路径,路径应当为绝对路径
|
||||
- `DATA_PATH` 需要进行推理的数据集目录,图像格式应当为jpg
|
||||
- `NEED_PREPROCESS` 表示属性编辑文件是否需要预处理,可以在y或者n中选择,如果选择y,表示进行预处理(在第一次执行推理时需要对属性编辑文件进行预处理,图片较多的话需要一些时间)
|
||||
- `DEVICE_ID` 可选,默认值为0.
|
||||
|
||||
[注] 属性编辑文件的格式可以参考celeba数据集中的list_attr_celeba.txt文件,第一行为要推理的图片数目,第二行为要编辑的属性,接下来的是要编辑的图片名称和属性tag。属性编辑文件中的图片数目必须和数据集目录中的图片数相同。
|
||||
|
||||
### 结果
|
||||
|
||||
推理结果保存在脚本执行的目录下,属性编辑后的图片保存在`result_Files/`目录下,推理的时间统计结果保存在`time_Result/`目录下。编辑后的图片以`imgName_attrId.jpg`的格式保存,如`182001_1.jpg`表示对名称为182001的第一个属性进行编辑后的结果,是否对该属性进行编辑根据属性编辑文件的内容决定。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
void Denorm(std::vector<mindspore::MSTensor> *outputs);
|
||||
std::string RealPath(std::string_view path);
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||
#endif
|
|
@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/vision_ascend.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
using color_rep_type = std::underlying_type<mindspore::DataType>::type;
|
||||
|
||||
DEFINE_string(gen_mindir_path, "", "generator mindir path");
|
||||
DEFINE_string(dataset_path, "", "dataset path");
|
||||
DEFINE_string(attr_file_path, "", "attribute file path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
DEFINE_int32(image_height, 128, "image height");
|
||||
DEFINE_int32(image_width, 128, "image width");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
if (RealPath(FLAGS_gen_mindir_path).empty()) {
|
||||
std::cout << "Invalid generator mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
ascend310->SetBufferOptimizeMode("off_optimize");
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
|
||||
mindspore::Graph gen_graph;
|
||||
Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph);
|
||||
|
||||
Model gen_model;
|
||||
Status gen_ret = gen_model.Build(GraphCell(gen_graph), context);
|
||||
|
||||
if (gen_ret != kSuccess) {
|
||||
std::cout << "ERROR: Generator build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t n_attrs = 13;
|
||||
auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs);
|
||||
auto all_files = GetAllFiles(FLAGS_dataset_path);
|
||||
std::map<double, double> costTime_map;
|
||||
double startTimeMs;
|
||||
double endTimeMs;
|
||||
size_t size = all_files.size();
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
std::cout << "Start predict input files:" << all_files[i] << std::endl;
|
||||
|
||||
auto img = std::make_shared<MSTensor>();
|
||||
std::shared_ptr<TensorTransform> decode(new Decode());
|
||||
std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
|
||||
auto resizeShape = {FLAGS_image_height, FLAGS_image_width};
|
||||
std::shared_ptr<TensorTransform> resize(new Resize(resizeShape));
|
||||
std::shared_ptr<TensorTransform> normalize(
|
||||
new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
|
||||
Execute composeDecode({decode, resize, normalize, hwc2chw});
|
||||
auto image = ReadFileToTensor(all_files[i]);
|
||||
composeDecode(image, img.get());
|
||||
|
||||
for (size_t k = 0; k < n_attrs; ++k) {
|
||||
gettimeofday(&start, nullptr);
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
size_t index = i * n_attrs + k;
|
||||
std::cout << static_cast<color_rep_type>(img->DataType()) << std::endl;
|
||||
inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize());
|
||||
inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(),
|
||||
all_cfg[index].Data().get(), all_cfg[index].DataSize());
|
||||
Status gen_model_ret = gen_model.Predict(inputs, &outputs);
|
||||
if (gen_model_ret != kSuccess) {
|
||||
std::cout << "Generator inference " << all_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
Denorm(&outputs);
|
||||
int pos = all_files[i].find('.');
|
||||
std::string fileName = all_files[i].substr(0, pos);
|
||||
WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
}
|
||||
}
|
||||
double average = 0.0;
|
||||
int inferCount = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
inferCount++;
|
||||
}
|
||||
average = average / inferCount;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||
fileStream << timeCost.str();
|
||||
fileStream.close();
|
||||
costTime_map.clear();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "inc/utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
for (auto &f : res) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||
std::string homePath = "./result_Files";
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
size_t outputSize;
|
||||
std::shared_ptr<const void> netOutput;
|
||||
netOutput = outputs[i].Data();
|
||||
outputSize = outputs[i].DataSize();
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
std::cout << fileName << std::endl;
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE * outputFile = fopen(outFileName.c_str(), "wb");
|
||||
fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
|
||||
{static_cast<int64_t>(size)}, nullptr, size);
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::vector<std::string> split(std::string inputs) {
|
||||
std::vector<std::string> line;
|
||||
std::stringstream stream(inputs);
|
||||
std::string result;
|
||||
while ( stream >> result ) {
|
||||
line.push_back(result);
|
||||
}
|
||||
return line;
|
||||
}
|
||||
|
||||
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr) {
|
||||
std::vector<mindspore::MSTensor> res;
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr." << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist." << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << " open failed." << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string n_images;
|
||||
std::string n_attrs;
|
||||
getline(ifs, n_images);
|
||||
getline(ifs, n_attrs);
|
||||
|
||||
auto n_images_ = std::stoi(n_images);
|
||||
auto n_attrs_ = std::stoi(n_attrs);
|
||||
*n_ptr = n_attrs_;
|
||||
std::cout << "Image number is " << n_images << std::endl;
|
||||
std::cout << "Attribute number is " << n_attrs << std::endl;
|
||||
|
||||
auto all_lines = n_images_ * n_attrs_;
|
||||
for (auto i = 0; i < all_lines; i++) {
|
||||
std::string val;
|
||||
getline(ifs, val);
|
||||
std::vector<std::string> val_split = split(val);
|
||||
void *data = malloc(sizeof(float)*n_attrs_);
|
||||
float *elements = reinterpret_cast<float *>(data);
|
||||
for (auto j = 0; j < n_attrs_; j++) elements[j] = atof(val_split[j].c_str());
|
||||
auto size = sizeof(float) * n_attrs_;
|
||||
mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32,
|
||||
{static_cast<int64_t>(size)}, nullptr, size);
|
||||
memcpy(buffer.MutableData(), elements, size);
|
||||
res.emplace_back(buffer);
|
||||
}
|
||||
ifs.close();
|
||||
return res;
|
||||
}
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir;
|
||||
dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
||||
|
||||
void Denorm(std::vector<MSTensor> *outputs) {
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
size_t outputSize = (*outputs)[i].DataSize();
|
||||
float* netOutput = reinterpret_cast<float *>((*outputs)[i].MutableData());
|
||||
size_t outputLen = outputSize / sizeof(float);
|
||||
|
||||
for (size_t j = 0; j < outputLen; ++j) {
|
||||
netOutput[j] = (netOutput[j] + 1) / 2 * 255;
|
||||
netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j];
|
||||
netOutput[j] = (netOutput[j] > 255) ? 255 : netOutput[j];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,12 +28,12 @@ import mindspore.dataset as de
|
|||
from mindspore import context, Tensor, ops
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
|
||||
from src.attgan import Genc, Gdec
|
||||
from src.attgan import Gen
|
||||
from src.cell import init_weights
|
||||
from src.data import check_attribute_conflict
|
||||
from src.data import get_loader, Custom
|
||||
from src.helpers import Progressbar
|
||||
from src.utils import resume_model, denorm
|
||||
from src.utils import resume_generator, denorm
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
@ -45,8 +45,7 @@ def parse(arg=None):
|
|||
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
|
||||
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
|
||||
parser.add_argument('--num_test', dest='num_test', type=int)
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--gen_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--custom_img', action='store_true')
|
||||
parser.add_argument('--custom_data', type=str, default='../data/custom')
|
||||
parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
|
||||
|
@ -62,8 +61,7 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
|
|||
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
|
||||
args.test_int = args_.test_int
|
||||
args.num_test = args_.num_test
|
||||
args.enc_ckpt_name = args_.enc_ckpt_name
|
||||
args.dec_ckpt_name = args_.dec_ckpt_name
|
||||
args.gen_ckpt_name = args_.gen_ckpt_name
|
||||
args.custom_img = args_.custom_img
|
||||
args.custom_data = args_.custom_data
|
||||
args.custom_attr = args_.custom_attr
|
||||
|
@ -101,15 +99,15 @@ else:
|
|||
print('Testing images:', min(test_len, args.num_test))
|
||||
|
||||
# Model loader
|
||||
genc = Genc(mode='test')
|
||||
gdec = Gdec(shortcut_layers=args.shortcut_layers, inject_layers=args.inject_layers, mode='test')
|
||||
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
|
||||
args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test')
|
||||
|
||||
# Initialize network
|
||||
init_weights(genc, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
init_weights(gen, 'KaimingUniform', math.sqrt(5))
|
||||
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
|
||||
load_param_into_net(gen, para_gen)
|
||||
|
||||
print("Network initializes successfully.")
|
||||
|
||||
progressbar = Progressbar()
|
||||
it = 0
|
||||
|
@ -134,8 +132,7 @@ for data in test_dataset_iter:
|
|||
att_b_ = (att_b * 2 - 1) * args.thres_int
|
||||
if i > 0:
|
||||
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
|
||||
a_enc = genc(img_a)
|
||||
samples.append(gdec(a_enc, att_b_))
|
||||
samples.append(gen(img_a, att_b_, mode="enc-dec"))
|
||||
cat = ops.Concat(axis=3)
|
||||
samples = cat(samples).asnumpy()
|
||||
result = denorm(samples)
|
||||
|
|
|
@ -21,14 +21,13 @@ import numpy as np
|
|||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export, load_param_into_net
|
||||
|
||||
from src.utils import resume_model
|
||||
from src.attgan import Genc, Gdec
|
||||
from src.utils import resume_generator
|
||||
from src.attgan import Gen
|
||||
|
||||
parser = argparse.ArgumentParser(description='Attribute Edit')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--gen_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
|
||||
|
||||
|
@ -39,27 +38,20 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
|
|||
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
|
||||
args.device_id = args_.device_id
|
||||
args.batch_size = args_.batch_size
|
||||
args.enc_ckpt_name = args_.enc_ckpt_name
|
||||
args.dec_ckpt_name = args_.dec_ckpt_name
|
||||
args.gen_ckpt_name = args_.gen_ckpt_name
|
||||
args.file_format = args_.file_format
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
genc = Genc(mode='test')
|
||||
gdec = Gdec(mode='test')
|
||||
gen = Gen(mode="test")
|
||||
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
|
||||
load_param_into_net(gen, para_gen)
|
||||
|
||||
enc_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
|
||||
dec_array = genc(enc_array)
|
||||
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
|
||||
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))
|
||||
|
||||
G_enc_file = f"AttGAN_Generator_Encoder"
|
||||
export(genc, enc_array, file_name=G_enc_file, file_format=args.file_format)
|
||||
|
||||
G_dec_file = f"AttGAN_Generator_Decoder"
|
||||
export(gdec, *(dec_array, input_label), file_name=G_dec_file, file_format=args.file_format)
|
||||
Gen_file = f"attgan_mindir"
|
||||
export(gen, *(input_array, input_label), file_name=Gen_file, file_format=args.file_format)
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""post process for 310 inference"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def parse(arg=None):
|
||||
"""Define configuration of postprocess"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--bin_path', type=str, default='./result_Files/')
|
||||
parser.add_argument('--target_path', type=str, default='./result_Files/')
|
||||
return parser.parse_args(arg)
|
||||
|
||||
def load_bin_file(bin_file, shape=None, dtype="float32"):
|
||||
"""Load data from bin file"""
|
||||
data = np.fromfile(bin_file, dtype=dtype)
|
||||
if shape:
|
||||
data = np.reshape(data, shape)
|
||||
return data
|
||||
|
||||
def save_bin_to_image(data, out_name):
|
||||
"""Save bin file to image arrays"""
|
||||
image = np.transpose(data, (1, 2, 0))
|
||||
im = Image.fromarray(np.uint8(image))
|
||||
im.save(out_name)
|
||||
print("Successfully save image in " + out_name)
|
||||
|
||||
def scan_dir(bin_path):
|
||||
"""Scan directory"""
|
||||
out = os.listdir(bin_path)
|
||||
return out
|
||||
|
||||
def postprocess(bin_path):
|
||||
"""Post process bin file"""
|
||||
file_list = scan_dir(bin_path)
|
||||
for file in file_list:
|
||||
data = load_bin_file(bin_path + file, shape=(3, 128, 128), dtype="float32")
|
||||
pos = file.find(".")
|
||||
file_name = file[0:pos] + "." + "jpg"
|
||||
outfile = os.path.join(args.target_path, file_name)
|
||||
save_bin_to_image(data, outfile)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse()
|
||||
postprocess(args.bin_path)
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""pre process for 310 inference"""
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
selected_attrs = [
|
||||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
|
||||
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
|
||||
]
|
||||
|
||||
def parse(arg=None):
|
||||
"""Define configuration of preprocess"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
|
||||
parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
|
||||
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
|
||||
parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
|
||||
return parser.parse_args(arg)
|
||||
|
||||
args = parse()
|
||||
args.n_attrs = len(args.attrs)
|
||||
|
||||
def check_attribute_conflict(att_batch, att_name, att_names):
|
||||
"""Check Attributes"""
|
||||
def _set(att, att_name):
|
||||
if att_name in att_names:
|
||||
att[att_names.index(att_name)] = 0.0
|
||||
|
||||
att_id = att_names.index(att_name)
|
||||
for att in att_batch:
|
||||
if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
|
||||
_set(att, 'Bangs')
|
||||
elif att_name == 'Bangs' and att[att_id] != 0:
|
||||
_set(att, 'Bald')
|
||||
_set(att, 'Receding_Hairline')
|
||||
elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
|
||||
for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
|
||||
for n in ['Straight_Hair', 'Wavy_Hair']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
|
||||
for n in ['Mustache', 'No_Beard']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
return att_batch
|
||||
|
||||
def read_cfg_file(attr_path):
|
||||
"""Read configuration from attribute file"""
|
||||
attr_list = open(attr_path, "r", encoding="utf-8").readlines()[1].split()
|
||||
atts = [attr_list.index(att) + 1 for att in selected_attrs]
|
||||
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
|
||||
attr_number = int(open(attr_path, "r", encoding="utf-8").readlines()[0])
|
||||
labels = [labels] if attr_number == 1 else labels[0:]
|
||||
new_attr = []
|
||||
for index in range(attr_number):
|
||||
att = [np.asarray((labels[index] + 1) // 2)]
|
||||
new_attr.append(att)
|
||||
new_attr = np.array(new_attr)
|
||||
return new_attr, attr_number
|
||||
|
||||
def preprocess_cfg(attrs, numbers):
|
||||
"""Preprocess attribute file"""
|
||||
new_attr = []
|
||||
for index in range(numbers):
|
||||
attr = attrs[index]
|
||||
att_b_list = [attr]
|
||||
for i in range(args.n_attrs):
|
||||
tmp = attr.copy()
|
||||
tmp[:, i] = 1 - tmp[:, i]
|
||||
tmp = check_attribute_conflict(tmp, selected_attrs[i], selected_attrs)
|
||||
att_b_list.append(tmp)
|
||||
for i, att_b in enumerate(att_b_list):
|
||||
att_b_ = (att_b * 2 - 1) * args.thres_int
|
||||
if i > 0:
|
||||
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
|
||||
new_attr.append(att_b_)
|
||||
return new_attr
|
||||
|
||||
def write_cfg_file(attrs, numbers):
|
||||
"""Write attribute file"""
|
||||
cur_dir = os.getcwd()
|
||||
print(cur_dir)
|
||||
path = join(cur_dir, 'attrs.txt')
|
||||
with open(path, "w") as f:
|
||||
f.writelines(str(numbers))
|
||||
f.writelines("\n")
|
||||
f.writelines(str(args.n_attrs))
|
||||
f.writelines("\n")
|
||||
counts = numbers * args.n_attrs
|
||||
for index in range(counts):
|
||||
attrs_list = attrs[index][0]
|
||||
new_attrs_list = ["%s" % x for x in attrs_list]
|
||||
sequence = " ".join(new_attrs_list)
|
||||
f.writelines(sequence)
|
||||
f.writelines("\n")
|
||||
print("Generate cfg file successfully.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if args.attrs_path is None:
|
||||
print("Path is not correct!")
|
||||
attributes, n_images = read_cfg_file(args.attrs_path)
|
||||
new_attrs = preprocess_cfg(attributes, n_images)
|
||||
write_cfg_file(new_attrs, n_images)
|
|
@ -14,17 +14,16 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 5 ]
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [ENC_CKPT_NAME] [DEC_CKPT_NAME]"
|
||||
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
experiment_name=$1
|
||||
data_path=$2
|
||||
attr_path=$3
|
||||
enc_name=$4
|
||||
dec_name=$5
|
||||
gen_ckpt_name=$4
|
||||
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "The number of logical core" $cores
|
||||
|
@ -47,5 +46,4 @@ python eval.py \
|
|||
--custom_data $data_path \
|
||||
--custom_attr $attr_path \
|
||||
--custom_img \
|
||||
--enc_ckpt_name $enc_name \
|
||||
--dec_ckpt_name $dec_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
|
||||
--gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'.
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
gen_model=$(get_real_path $1)
|
||||
attr_path=$(get_real_path $2)
|
||||
data_path=$(get_real_path $3)
|
||||
|
||||
if [ "$4" == "y" ] || [ "$4" == "n" ];then
|
||||
need_preprocess=$4
|
||||
else
|
||||
echo "weather need preprocess or not, it's value must be in [y, n]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
device_id=0
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
fi
|
||||
|
||||
echo "generator mindir name: "$gen_model
|
||||
echo "attribute file path: "$attr_path
|
||||
echo "dataset path: "$data_path
|
||||
echo "need preprocess: "$need_preprocess
|
||||
echo "device id: "$device_id
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function preprocess_data()
|
||||
{
|
||||
echo "Start to preprocess attr file..."
|
||||
python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log
|
||||
echo "Attribute file generates successfully!"
|
||||
}
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
echo "Start to compile source code..."
|
||||
cd ../ascend310_infer || exit
|
||||
bash build.sh &> build.log
|
||||
echo "Compile successfully."
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
cd - || exit
|
||||
if [ -d result_Files ]; then
|
||||
rm -rf ./result_Files
|
||||
fi
|
||||
if [ -d time_Result ]; then
|
||||
rm -rf ./time_Result
|
||||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
echo "Start to execute inference..."
|
||||
../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log
|
||||
}
|
||||
|
||||
function postprocess_data()
|
||||
{
|
||||
echo "Start to postprocess image file..."
|
||||
python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
|
||||
}
|
||||
|
||||
if [ $need_preprocess == "y" ]; then
|
||||
preprocess_data
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "preprocess attrs failed"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
postprocess_data
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "postprocess images failed"
|
||||
exit 1
|
||||
fi
|
|
@ -22,12 +22,15 @@ from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
|
|||
# Image size 128 x 128
|
||||
MAX_DIM = 64 * 16
|
||||
|
||||
|
||||
class Genc(nn.Cell):
|
||||
"""Generator encoder"""
|
||||
class Gen(nn.Cell):
|
||||
"""Generator"""
|
||||
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
|
||||
mode='test'):
|
||||
dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu",
|
||||
n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"):
|
||||
super().__init__()
|
||||
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
|
||||
self.inject_layers = min(inject_layers, dec_layers - 1)
|
||||
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
|
||||
|
||||
layers = []
|
||||
n_in = 3
|
||||
|
@ -39,27 +42,7 @@ class Genc(nn.Cell):
|
|||
n_in = n_out
|
||||
self.enc_layers = nn.CellList(layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""Encoder construct"""
|
||||
z = x
|
||||
zs = []
|
||||
for layer in self.enc_layers:
|
||||
z = layer(z)
|
||||
zs.append(z)
|
||||
return zs
|
||||
|
||||
|
||||
class Gdec(nn.Cell):
|
||||
"""Generator decoder"""
|
||||
def __init__(self, dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", n_attrs=13,
|
||||
shortcut_layers=1, inject_layers=1, img_size=128, mode='test'):
|
||||
super().__init__()
|
||||
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
|
||||
self.inject_layers = min(inject_layers, dec_layers - 1)
|
||||
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
|
||||
|
||||
layers = []
|
||||
n_in = 1024
|
||||
n_in = n_in + n_attrs # 1024 + 13
|
||||
for i in range(dec_layers):
|
||||
if i < dec_layers - 1:
|
||||
|
@ -80,7 +63,16 @@ class Gdec(nn.Cell):
|
|||
self.repeat = P.Tile()
|
||||
self.cat = P.Concat(1)
|
||||
|
||||
def construct(self, zs, a):
|
||||
def encoder(self, x):
|
||||
"""Encoder construct"""
|
||||
z = x
|
||||
zs = []
|
||||
for layer in self.enc_layers:
|
||||
z = layer(z)
|
||||
zs.append(z)
|
||||
return zs
|
||||
|
||||
def decoder(self, zs, a):
|
||||
"""Decoder construct"""
|
||||
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
|
||||
multiples = (1, 1, self.f_size, self.f_size)
|
||||
|
@ -100,6 +92,16 @@ class Gdec(nn.Cell):
|
|||
i = i + 1
|
||||
return z
|
||||
|
||||
def construct(self, x, a=None, mode="enc-dec"):
|
||||
result = None
|
||||
if mode == "enc-dec":
|
||||
out = self.encoder(x)
|
||||
result = self.decoder(out, a)
|
||||
if mode == "enc":
|
||||
result = self.encoder(x)
|
||||
if mode == "dec":
|
||||
result = self.decoder(x, a)
|
||||
return result
|
||||
|
||||
class Dis(nn.Cell):
|
||||
"""Discriminator"""
|
||||
|
|
|
@ -116,8 +116,7 @@ class TrainOneStepCellGen(nn.Cell):
|
|||
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
self.optimizer(grads)
|
||||
return loss, gf_loss, gc_loss, gr_loss
|
||||
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
|
||||
|
||||
|
||||
class TrainOneStepCellDis(nn.Cell):
|
||||
|
@ -153,5 +152,4 @@ class TrainOneStepCellDis(nn.Cell):
|
|||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
self.optimizer(grads)
|
||||
return loss, d_real_loss, d_fake_loss, dc_loss, df_gp
|
||||
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp
|
||||
|
|
|
@ -90,10 +90,9 @@ class WGANGPGradientPenalty(nn.Cell):
|
|||
class GenLoss(nn.Cell):
|
||||
"""Define total Generator loss"""
|
||||
|
||||
def __init__(self, args, encoder, decoder, discriminator):
|
||||
def __init__(self, args, generator, discriminator):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
|
||||
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
|
||||
|
@ -108,9 +107,9 @@ class GenLoss(nn.Cell):
|
|||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
"""Get generator loss"""
|
||||
# generate
|
||||
zs_a = self.encoder(img_a)
|
||||
img_fake = self.decoder(zs_a, att_b_)
|
||||
img_recon = self.decoder(zs_a, att_a_)
|
||||
zs_a = self.generator(img_a, mode="enc")
|
||||
img_fake = self.generator(zs_a, att_b_, mode="dec")
|
||||
img_recon = self.generator(zs_a, att_a_, mode="dec")
|
||||
|
||||
# discriminate
|
||||
d_fake, dc_fake = self.discriminator(img_fake)
|
||||
|
@ -128,10 +127,9 @@ class GenLoss(nn.Cell):
|
|||
class DisLoss(nn.Cell):
|
||||
"""Define total discriminator loss"""
|
||||
|
||||
def __init__(self, args, encoder, decoder, discriminator):
|
||||
def __init__(self, args, generator, discriminator):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
|
||||
self.cyc_loss = P.ReduceMean()
|
||||
|
@ -146,8 +144,7 @@ class DisLoss(nn.Cell):
|
|||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
"""Get discriminator loss"""
|
||||
# generate
|
||||
z = self.encoder(img_a)
|
||||
img_fake = self.decoder(z, att_b_)
|
||||
img_fake = self.generator(img_a, att_b_, mode="enc-dec")
|
||||
|
||||
# discriminate
|
||||
d_real, dc_real = self.discriminator(img_a)
|
||||
|
|
|
@ -60,16 +60,13 @@ class DistributedSampler:
|
|||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name):
|
||||
"""Restore the trained generator and discriminator"""
|
||||
def resume_generator(args, generator, gen_ckpt_name):
|
||||
"""Restore the trained generator"""
|
||||
print("Loading the trained models from step {}...".format(args.save_interval))
|
||||
encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name)
|
||||
decoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name)
|
||||
generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name)
|
||||
param_generator = load_checkpoint(generator_path, generator)
|
||||
|
||||
param_encoder = load_checkpoint(encoder_path, encoder)
|
||||
param_decoder = load_checkpoint(decoder_path, decoder)
|
||||
|
||||
return param_encoder, param_decoder
|
||||
return param_generator
|
||||
|
||||
def resume_discriminator(args, discriminator, dis_ckpt_name):
|
||||
"""Restore the trained discriminator"""
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================sss
|
||||
# ============================================================================
|
||||
"""Entry point for training AttGAN network"""
|
||||
|
||||
import argparse
|
||||
|
@ -32,12 +32,12 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
|
||||
from src.attgan import Genc, Gdec, Dis
|
||||
from src.attgan import Gen, Dis
|
||||
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
|
||||
from src.data import data_loader
|
||||
from src.helpers import Progressbar
|
||||
from src.loss import GenLoss, DisLoss
|
||||
from src.utils import resume_model, resume_discriminator
|
||||
from src.utils import resume_generator, resume_discriminator
|
||||
|
||||
attrs_default = [
|
||||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
|
||||
|
@ -94,8 +94,7 @@ def parse(arg=None):
|
|||
default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
|
||||
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
|
||||
parser.add_argument('--resume_model', action='store_true')
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--gen_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dis_ckpt_name', type=str, default='')
|
||||
|
||||
return parser.parse_args(arg)
|
||||
|
@ -150,32 +149,28 @@ if __name__ == '__main__':
|
|||
print('Training images:', train_length)
|
||||
|
||||
# Define network
|
||||
genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train')
|
||||
gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers,
|
||||
args.inject_layers, args.img_size, mode='train')
|
||||
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
|
||||
args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train')
|
||||
dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
|
||||
args.dis_layers, args.img_size, mode='train')
|
||||
|
||||
# Initialize network
|
||||
init_weights(genc, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(gen, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(dis, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
# Resume from checkpoint
|
||||
if args.resume_model:
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
|
||||
para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
load_param_into_net(gen, para_gen)
|
||||
load_param_into_net(dis, para_dis)
|
||||
|
||||
# Define network with loss
|
||||
G_loss_cell = GenLoss(args, genc, gdec, dis)
|
||||
D_loss_cell = DisLoss(args, genc, gdec, dis)
|
||||
G_loss_cell = GenLoss(args, gen, dis)
|
||||
D_loss_cell = DisLoss(args, gen, dis)
|
||||
|
||||
# Define Optimizer
|
||||
G_trainable_params = genc.trainable_params() + gdec.trainable_params()
|
||||
optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
|
||||
optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
|
||||
optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
|
||||
|
||||
# Define One Step Train
|
||||
|
@ -193,21 +188,14 @@ if __name__ == '__main__':
|
|||
|
||||
if rank == 0:
|
||||
local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
|
||||
ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder')
|
||||
ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder')
|
||||
ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator')
|
||||
ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')
|
||||
|
||||
cb_params_genc = _InternalCallbackParam()
|
||||
cb_params_genc.train_network = genc
|
||||
cb_params_genc.cur_epoch_num = 0
|
||||
genc_run_context = RunContext(cb_params_genc)
|
||||
ckpt_cb_genc.begin(genc_run_context)
|
||||
|
||||
cb_params_gdec = _InternalCallbackParam()
|
||||
cb_params_gdec.train_network = gdec
|
||||
cb_params_gdec.cur_epoch_num = 0
|
||||
gdec_run_context = RunContext(cb_params_gdec)
|
||||
ckpt_cb_gdec.begin(gdec_run_context)
|
||||
cb_params_gen = _InternalCallbackParam()
|
||||
cb_params_gen.train_network = gen
|
||||
cb_params_gen.cur_epoch_num = 0
|
||||
gen_run_context = RunContext(cb_params_gen)
|
||||
ckpt_cb_gen.begin(gen_run_context)
|
||||
|
||||
cb_params_dis = _InternalCallbackParam()
|
||||
cb_params_dis.train_network = dis
|
||||
|
@ -243,16 +231,12 @@ if __name__ == '__main__':
|
|||
gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)
|
||||
|
||||
if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
|
||||
cb_params_genc.cur_epoch_num = epoch + 1
|
||||
cb_params_gdec.cur_epoch_num = epoch + 1
|
||||
cb_params_gen.cur_epoch_num = epoch + 1
|
||||
cb_params_dis.cur_epoch_num = epoch + 1
|
||||
cb_params_genc.cur_step_num = it + 1
|
||||
cb_params_gdec.cur_step_num = it + 1
|
||||
cb_params_gen.cur_step_num = it + 1
|
||||
cb_params_dis.cur_step_num = it + 1
|
||||
cb_params_genc.batch_num = it + 2
|
||||
cb_params_gdec.batch_num = it + 2
|
||||
cb_params_gen.batch_num = it + 2
|
||||
cb_params_dis.batch_num = it + 2
|
||||
ckpt_cb_genc.step_end(genc_run_context)
|
||||
ckpt_cb_gdec.step_end(gdec_run_context)
|
||||
ckpt_cb_gen.step_end(gen_run_context)
|
||||
ckpt_cb_dis.step_end(dis_run_context)
|
||||
it += 1
|
||||
|
|
Loading…
Reference in New Issue