forked from mindspore-Ecosystem/mindspore
commit fc746f737c

@@ -17,6 +17,8 @@
- [Evaluation](#evaluation)
- [Inference Process](#inference-process)
    - [Exporting MindIR](#exporting-mindir)
    - [Inference on Ascend 310](#inference-on-ascend-310)
    - [Results](#results)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)

@@ -53,7 +55,7 @@
- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more details, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

@@ -70,17 +72,17 @@
export RANK_SIZE=1
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
OR
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt

# Run the distributed training example
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt

# Run the evaluation example
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
OR
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name
```

For distributed training, an hccl configuration file in JSON format must be created in advance. The absolute path of this file is passed as the first argument of the distributed training script.

@@ -90,7 +92,7 @@
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
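That directory contains the `hccl_tools.py` helper normally used to generate this file. A minimal sketch of its typical invocation follows; the device range and the resulting file name are assumptions, not taken from this repository:

```bash
# Sketch: generate an hccl config covering all 8 devices of one server.
# hccl_tools.py is assumed to have been obtained from the link above.
python hccl_tools.py --device_num "[0,8)"
# Pass the absolute path of the generated hccl_*.json file as the first
# argument of run_distribute_train.sh.
```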

For the evaluation script, a directory holding the custom images (jpg) and an attribute-editing file must be created in advance; see [Script and Sample Code](#script-and-sample-code) for a description of the attribute-editing file. The directory and the attribute-editing file correspond to the `custom_data` and `custom_attr` arguments, respectively. The training script places checkpoint files under the
`/output/{experiment_name}/checkpoint` directory by default, and the name of the checkpoint file (Generator) must be passed as an argument when running the script.

[Note] All of the paths above should be absolute paths.

@@ -102,10 +104,12 @@
.
└─ cv
  └─ AttGAN
    ├── ascend310_infer              # directory for Ascend 310 inference
    ├── scripts
      ├── run_distribute_train.sh    # shell script for distributed training
      ├── run_single_train.sh        # shell script for single-device training
      ├── run_eval.sh                # evaluation script
      ├── run_infer_310.sh           # inference script
    ├─ src
      ├─ __init__.py                 # init file
      ├─ block.py                    # basic cells
@@ -117,6 +121,9 @@
      ├─ loss.py                     # loss computation
    ├─ eval.py                       # test script
    ├─ train.py                      # training script
    ├─ export.py                     # MindIR model export script
    ├─ preprocess.py                 # preprocessing script for Ascend 310 inference
    ├─ postprocess.py                # postprocessing script for Ascend 310 inference
    └─ README_CN.md                  # description of the AttGAN files
```

@@ -141,7 +148,7 @@
- Running in the Ascend processor environment

```bash
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt
```

The shell script above runs distributed training in the background. For each process it creates a LOG{RANK_ID} directory under the script directory, and each process writes its output to log.txt in the corresponding LOG{RANK_ID} directory. Checkpoint files are saved under output/experiment_name/rank{RANK_ID}.
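For reference, a quick way to follow a run and locate its outputs, assuming rank 0 and the experiment name from the quick start (the exact locations follow the description above and depend on where the script was launched):

```bash
# Watch the rank-0 training log (LOG0 is created under the script directory).
tail -f LOG0/log.txt
# List the checkpoints written for rank 0.
ls output/128_shortcut1_inject1_none/rank0
```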

@@ -153,12 +160,12 @@
- Evaluating a custom dataset in the Ascend environment

This network can be used to edit facial attributes. Put the images you want to modify into a custom image directory and edit the attribute-editing file according to the attributes you want to change (for the exact format, refer to the CelebA dataset and its attribute file). Then pass the custom image directory and the attribute-editing file to the test script as the custom_data and custom_attr arguments, respectively.

For evaluation, choose an already generated checkpoint file and pass it to the test script via the `gen_ckpt_name` argument (it stores the parameters of both the encoder and the decoder).

```bash
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt
```

After the test script finishes, the edited images can be found under `output/{experiment_name}/custom_img` in the current directory.

@@ -168,7 +175,7 @@
### Exporting MindIR

```shell
python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT]
```

`file_format` must be chosen from ["AIR", "MINDIR"].

@@ -176,6 +183,26 @@

The script generates the corresponding MINDIR file in the current directory.
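As a concrete example, reusing the experiment and checkpoint names from the quick start above (illustrative values only):

```bash
python export.py --experiment_name 128_shortcut1_inject1_none --gen_ckpt_name generator-119_84999.ckpt --file_format MINDIR
# export.py names the output "attgan_mindir"; with the MINDIR format the saved
# file is expected to appear as attgan_mindir.mindir in the current directory.
```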

### Inference on Ascend 310

Before running inference, the MINDIR model must be exported with the export script. The following command performs attribute editing on images on Ascend 310 (a complete example invocation is shown after the parameter list below):

```bash
bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

- `GEN_MINDIR_PATH` path to the generator MINDIR file
- `ATTR_FILE_PATH` path to the attribute-editing file; this should be an absolute path
- `DATA_PATH` directory containing the dataset to run inference on; images should be in jpg format
- `NEED_PREPROCESS` whether the attribute-editing file needs to be preprocessed, y or n; choose y to run preprocessing (the attribute-editing file must be preprocessed the first time inference is run, which can take some time if there are many images)
- `DEVICE_ID` optional, defaults to 0.
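A minimal example invocation, assuming the MindIR exported above and the custom data paths used earlier (all paths are placeholders):

```bash
bash run_infer_310.sh /path/attgan_mindir.mindir /path/data/list_attr_custom.txt /path/data/custom/ y 0
```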

[Note] The format of the attribute-editing file can follow the list_attr_celeba.txt file of the CelebA dataset: the first line is the number of images to run inference on, the second line lists the attributes to edit, and each following line gives an image name and its attribute tags. The number of images in the attribute-editing file must match the number of images in the dataset directory.
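An illustrative attribute-editing file for two images could look as follows; the attribute names are the 13 attributes used by this model, while the image names and tag values are made up:

```text
2
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
182001.jpg -1 1 1 -1 -1 -1 -1 -1 1 -1 1 -1 1
182002.jpg -1 -1 -1 1 -1 -1 1 1 -1 -1 1 -1 -1
```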

### Results

The inference results are saved in the directory where the script is executed: the attribute-edited images are written to `result_Files/` and the inference timing statistics to `time_Result/`. Edited images are saved as `imgName_attrId.jpg`; for example, `182001_1.jpg` is the result of editing the first attribute of the image named 182001. Whether an attribute is actually edited is determined by the contents of the attribute-editing file.

# Model Description

## Performance
@@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"

DIR *OpenDir(std::string_view dirName);
void Denorm(std::vector<mindspore::MSTensor> *outputs);
std::string RealPath(std::string_view path);
std::vector<std::string> GetAllFiles(std::string_view dirName);
std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif
@@ -0,0 +1,149 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>

#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"

using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using color_rep_type = std::underlying_type<mindspore::DataType>::type;

DEFINE_string(gen_mindir_path, "", "generator mindir path");
DEFINE_string(dataset_path, "", "dataset path");
DEFINE_string(attr_file_path, "", "attribute file path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_int32(image_height, 128, "image height");
DEFINE_int32(image_width, 128, "image width");

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_gen_mindir_path).empty()) {
    std::cout << "Invalid generator mindir" << std::endl;
    return 1;
  }
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  ascend310->SetBufferOptimizeMode("off_optimize");
  context->MutableDeviceInfo().push_back(ascend310);

  mindspore::Graph gen_graph;
  Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph);

  Model gen_model;
  Status gen_ret = gen_model.Build(GraphCell(gen_graph), context);

  if (gen_ret != kSuccess) {
    std::cout << "ERROR: Generator build failed." << std::endl;
    return 1;
  }

  size_t n_attrs = 13;
  auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs);
  auto all_files = GetAllFiles(FLAGS_dataset_path);
  std::map<double, double> costTime_map;
  double startTimeMs;
  double endTimeMs;
  size_t size = all_files.size();

  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    std::cout << "Start predict input files:" << all_files[i] << std::endl;

    auto img = std::make_shared<MSTensor>();
    std::shared_ptr<TensorTransform> decode(new Decode());
    std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
    auto resizeShape = {FLAGS_image_height, FLAGS_image_width};
    std::shared_ptr<TensorTransform> resize(new Resize(resizeShape));
    std::shared_ptr<TensorTransform> normalize(
      new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
    Execute composeDecode({decode, resize, normalize, hwc2chw});
    auto image = ReadFileToTensor(all_files[i]);
    composeDecode(image, img.get());

    for (size_t k = 0; k < n_attrs; ++k) {
      gettimeofday(&start, nullptr);
      std::vector<MSTensor> inputs;
      std::vector<MSTensor> outputs;
      size_t index = i * n_attrs + k;
      std::cout << static_cast<color_rep_type>(img->DataType()) << std::endl;
      inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize());
      inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(),
                          all_cfg[index].Data().get(), all_cfg[index].DataSize());
      Status gen_model_ret = gen_model.Predict(inputs, &outputs);
      if (gen_model_ret != kSuccess) {
        std::cout << "Generator inference " << all_files[i] << " failed." << std::endl;
        return 1;
      }
      Denorm(&outputs);
      int pos = all_files[i].find('.');
      std::string fileName = all_files[i].substr(0, pos);
      WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs);
      gettimeofday(&end, nullptr);
      startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
      endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
      costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    }
  }
  double average = 0.0;
  int inferCount = 0;

  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  average = average / inferCount;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}
@@ -0,0 +1,201 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "inc/utils.h"

#include <fstream>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <sstream>

using mindspore::MSTensor;
using mindspore::DataType;

std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}

int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    std::cout << fileName << std::endl;
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE * outputFile = fopen(outFileName.c_str(), "wb");
    fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}

mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "Pointer file is nullptr" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " is not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << "open failed" << std::endl;
    return mindspore::MSTensor();
  }

  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
                             {static_cast<int64_t>(size)}, nullptr, size);
  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();
  return buffer;
}

std::vector<std::string> split(std::string inputs) {
  std::vector<std::string> line;
  std::stringstream stream(inputs);
  std::string result;
  while ( stream >> result ) {
    line.push_back(result);
  }
  return line;
}

std::vector<mindspore::MSTensor> ReadCfgToTensor(const std::string &file, size_t *n_ptr) {
  std::vector<mindspore::MSTensor> res;
  if (file.empty()) {
    std::cout << "Pointer file is nullptr." << std::endl;
    exit(1);
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " is not exist." << std::endl;
    exit(1);
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed." << std::endl;
    exit(1);
  }

  std::string n_images;
  std::string n_attrs;
  getline(ifs, n_images);
  getline(ifs, n_attrs);

  auto n_images_ = std::stoi(n_images);
  auto n_attrs_ = std::stoi(n_attrs);
  *n_ptr = n_attrs_;
  std::cout << "Image number is " << n_images << std::endl;
  std::cout << "Attribute number is " << n_attrs << std::endl;

  auto all_lines = n_images_ * n_attrs_;
  for (auto i = 0; i < all_lines; i++) {
    std::string val;
    getline(ifs, val);
    std::vector<std::string> val_split = split(val);
    void *data = malloc(sizeof(float)*n_attrs_);
    float *elements = reinterpret_cast<float *>(data);
    for (auto j = 0; j < n_attrs_; j++) elements[j] = atof(val_split[j].c_str());
    auto size = sizeof(float) * n_attrs_;
    mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32,
                               {static_cast<int64_t>(size)}, nullptr, size);
    memcpy(buffer.MutableData(), elements, size);
    res.emplace_back(buffer);
  }
  ifs.close();
  return res;
}

DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  lstat(realPath.c_str(), &s);
  if (!S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}

std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  realPathRet = realpath(path.data(), realPathMem);

  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " is not exist.";
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}

void Denorm(std::vector<MSTensor> *outputs) {
  for (size_t i = 0; i < outputs->size(); ++i) {
    size_t outputSize = (*outputs)[i].DataSize();
    float* netOutput = reinterpret_cast<float *>((*outputs)[i].MutableData());
    size_t outputLen = outputSize / sizeof(float);

    for (size_t j = 0; j < outputLen; ++j) {
      netOutput[j] = (netOutput[j] + 1) / 2 * 255;
      netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j];
      netOutput[j] = (netOutput[j] > 255) ? 255 : netOutput[j];
    }
  }
}
@@ -28,12 +28,12 @@ import mindspore.dataset as de
from mindspore import context, Tensor, ops
from mindspore.train.serialization import load_param_into_net

from src.attgan import Gen
from src.cell import init_weights
from src.data import check_attribute_conflict
from src.data import get_loader, Custom
from src.helpers import Progressbar
from src.utils import resume_generator, denorm

device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

@@ -45,8 +45,7 @@ def parse(arg=None):
    parser.add_argument('--experiment_name', dest='experiment_name', required=True)
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--num_test', dest='num_test', type=int)
    parser.add_argument('--gen_ckpt_name', type=str, default='')
    parser.add_argument('--custom_img', action='store_true')
    parser.add_argument('--custom_data', type=str, default='../data/custom')
    parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')

@@ -62,8 +61,7 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.gen_ckpt_name = args_.gen_ckpt_name
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr

@@ -101,15 +99,15 @@ else:
    print('Testing images:', min(test_len, args.num_test))

# Model loader
gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
          args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test')

# Initialize network
init_weights(gen, 'KaimingUniform', math.sqrt(5))
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
load_param_into_net(gen, para_gen)
print("Network initializes successfully.")

progressbar = Progressbar()
it = 0

@@ -134,8 +132,7 @@ for data in test_dataset_iter:
        att_b_ = (att_b * 2 - 1) * args.thres_int
        if i > 0:
            att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
        samples.append(gen(img_a, att_b_, mode="enc-dec"))
    cat = ops.Concat(axis=3)
    samples = cat(samples).asnumpy()
    result = denorm(samples)
@@ -21,14 +21,13 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net

from src.utils import resume_generator
from src.attgan import Gen

parser = argparse.ArgumentParser(description='Attribute Edit')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--gen_ckpt_name', type=str, default='')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--experiment_name', dest='experiment_name', required=True)

@@ -39,27 +38,20 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.device_id = args_.device_id
args.batch_size = args_.batch_size
args.gen_ckpt_name = args_.gen_ckpt_name
args.file_format = args_.file_format

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

if __name__ == '__main__':

    gen = Gen(mode="test")

    para_gen = resume_generator(args, gen, args.gen_ckpt_name)
    load_param_into_net(gen, para_gen)

    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
    input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))

    Gen_file = f"attgan_mindir"
    export(gen, *(input_array, input_label), file_name=Gen_file, file_format=args.file_format)
@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from PIL import Image

def parse(arg=None):
    """Define configuration of postprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--bin_path', type=str, default='./result_Files/')
    parser.add_argument('--target_path', type=str, default='./result_Files/')
    return parser.parse_args(arg)

def load_bin_file(bin_file, shape=None, dtype="float32"):
    """Load data from bin file"""
    data = np.fromfile(bin_file, dtype=dtype)
    if shape:
        data = np.reshape(data, shape)
    return data

def save_bin_to_image(data, out_name):
    """Save bin file to image arrays"""
    image = np.transpose(data, (1, 2, 0))
    im = Image.fromarray(np.uint8(image))
    im.save(out_name)
    print("Successfully save image in " + out_name)

def scan_dir(bin_path):
    """Scan directory"""
    out = os.listdir(bin_path)
    return out

def postprocess(bin_path):
    """Post process bin file"""
    file_list = scan_dir(bin_path)
    for file in file_list:
        data = load_bin_file(bin_path + file, shape=(3, 128, 128), dtype="float32")
        pos = file.find(".")
        file_name = file[0:pos] + "." + "jpg"
        outfile = os.path.join(args.target_path, file_name)
        save_bin_to_image(data, outfile)

if __name__ == "__main__":

    args = parse()
    postprocess(args.bin_path)
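For reference, run_infer_310.sh (shown later in this diff) calls this script after inference finishes; the equivalent standalone command, run from the model's root directory with the default result directory assumed, would be:

```bash
python postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
```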

@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pre process for 310 inference"""
import os
from os.path import join

import argparse
import numpy as np

selected_attrs = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]

def parse(arg=None):
    """Define configuration of preprocess"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
    parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    return parser.parse_args(arg)

args = parse()
args.n_attrs = len(args.attrs)

def check_attribute_conflict(att_batch, att_name, att_names):
    """Check Attributes"""
    def _set(att, att_name):
        if att_name in att_names:
            att[att_names.index(att_name)] = 0.0

    att_id = att_names.index(att_name)
    for att in att_batch:
        if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
            _set(att, 'Bangs')
        elif att_name == 'Bangs' and att[att_id] != 0:
            _set(att, 'Bald')
            _set(att, 'Receding_Hairline')
        elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
            for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
            for n in ['Straight_Hair', 'Wavy_Hair']:
                if n != att_name:
                    _set(att, n)
        elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
            for n in ['Mustache', 'No_Beard']:
                if n != att_name:
                    _set(att, n)
    return att_batch

def read_cfg_file(attr_path):
    """Read configuration from attribute file"""
    attr_list = open(attr_path, "r", encoding="utf-8").readlines()[1].split()
    atts = [attr_list.index(att) + 1 for att in selected_attrs]
    labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
    attr_number = int(open(attr_path, "r", encoding="utf-8").readlines()[0])
    labels = [labels] if attr_number == 1 else labels[0:]
    new_attr = []
    for index in range(attr_number):
        att = [np.asarray((labels[index] + 1) // 2)]
        new_attr.append(att)
    new_attr = np.array(new_attr)
    return new_attr, attr_number

def preprocess_cfg(attrs, numbers):
    """Preprocess attribute file"""
    new_attr = []
    for index in range(numbers):
        attr = attrs[index]
        att_b_list = [attr]
        for i in range(args.n_attrs):
            tmp = attr.copy()
            tmp[:, i] = 1 - tmp[:, i]
            tmp = check_attribute_conflict(tmp, selected_attrs[i], selected_attrs)
            att_b_list.append(tmp)
        for i, att_b in enumerate(att_b_list):
            att_b_ = (att_b * 2 - 1) * args.thres_int
            if i > 0:
                att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
            new_attr.append(att_b_)
    return new_attr

def write_cfg_file(attrs, numbers):
    """Write attribute file"""
    cur_dir = os.getcwd()
    print(cur_dir)
    path = join(cur_dir, 'attrs.txt')
    with open(path, "w") as f:
        f.writelines(str(numbers))
        f.writelines("\n")
        f.writelines(str(args.n_attrs))
        f.writelines("\n")
        counts = numbers * args.n_attrs
        for index in range(counts):
            attrs_list = attrs[index][0]
            new_attrs_list = ["%s" % x for x in attrs_list]
            sequence = " ".join(new_attrs_list)
            f.writelines(sequence)
            f.writelines("\n")
    print("Generate cfg file successfully.")

if __name__ == "__main__":

    if args.attrs_path is None:
        print("Path is not correct!")
    attributes, n_images = read_cfg_file(args.attrs_path)
    new_attrs = preprocess_cfg(attributes, n_images)
    write_cfg_file(new_attrs, n_images)
@@ -14,17 +14,16 @@
# limitations under the License.
# ============================================================================

if [ $# != 4 ]
then
    echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]"
    exit 1
fi

experiment_name=$1
data_path=$2
attr_path=$3
gen_ckpt_name=$4

cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "The number of logical core" $cores

@@ -47,5 +46,4 @@ python eval.py \
--custom_data $data_path \
--custom_attr $attr_path \
--custom_img \
--gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
@@ -0,0 +1,128 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [[ $# -lt 4 || $# -gt 5 ]]; then
    echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
    NEED_PREPROCESS indicates whether preprocessing is needed; its value is 'y' or 'n'.
    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

gen_model=$(get_real_path $1)
attr_path=$(get_real_path $2)
data_path=$(get_real_path $3)

if [ "$4" == "y" ] || [ "$4" == "n" ];then
    need_preprocess=$4
else
    echo "whether preprocessing is needed or not, its value must be in [y, n]"
    exit 1
fi

device_id=0
if [ $# == 5 ]; then
    device_id=$5
fi

echo "generator mindir name: "$gen_model
echo "attribute file path: "$attr_path
echo "dataset path: "$data_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id

export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function preprocess_data()
{
    echo "Start to preprocess attr file..."
    python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log
    echo "Attribute file generated successfully!"
}

function compile_app()
{
    echo "Start to compile source code..."
    cd ../ascend310_infer || exit
    bash build.sh &> build.log
    echo "Compiled successfully."
}

function infer()
{
    cd - || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result
    echo "Start to execute inference..."
    ../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log
}

function postprocess_data()
{
    echo "Start to postprocess image file..."
    python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/"
}

if [ $need_preprocess == "y" ]; then
    preprocess_data
    if [ $? -ne 0 ]; then
        echo "preprocess attrs failed"
        exit 1
    fi
fi

compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi

infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi

postprocess_data
if [ $? -ne 0 ]; then
    echo "postprocess images failed"
    exit 1
fi
@@ -22,12 +22,15 @@ from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
# Image size 128 x 128
MAX_DIM = 64 * 16

class Gen(nn.Cell):
    """Generator"""
    def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
                 dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu",
                 n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"):
        super().__init__()
        self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
        self.inject_layers = min(inject_layers, dec_layers - 1)
        self.f_size = img_size // 2 ** dec_layers  # f_size = 4 for 128x128

        layers = []
        n_in = 3

@@ -39,27 +42,7 @@ class Genc(nn.Cell):
            n_in = n_out
        self.enc_layers = nn.CellList(layers)

        layers = []
        n_in = n_in + n_attrs  # 1024 + 13
        for i in range(dec_layers):
            if i < dec_layers - 1:

@@ -80,7 +63,16 @@ class Gdec(nn.Cell):
        self.repeat = P.Tile()
        self.cat = P.Concat(1)

    def encoder(self, x):
        """Encoder construct"""
        z = x
        zs = []
        for layer in self.enc_layers:
            z = layer(z)
            zs.append(z)
        return zs

    def decoder(self, zs, a):
        """Decoder construct"""
        a_tile = self.view(a, (a.shape[0], -1, 1, 1))
        multiples = (1, 1, self.f_size, self.f_size)

@@ -100,6 +92,16 @@ class Gdec(nn.Cell):
            i = i + 1
        return z

    def construct(self, x, a=None, mode="enc-dec"):
        result = None
        if mode == "enc-dec":
            out = self.encoder(x)
            result = self.decoder(out, a)
        if mode == "enc":
            result = self.encoder(x)
        if mode == "dec":
            result = self.decoder(x, a)
        return result

class Dis(nn.Cell):
    """Discriminator"""
@@ -116,8 +116,7 @@ class TrainOneStepCellGen(nn.Cell):
        grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss


class TrainOneStepCellDis(nn.Cell):

@@ -153,5 +152,4 @@ class TrainOneStepCellDis(nn.Cell):
        if self.reducer_flag:
            grads = self.grad_reducer(grads)

        return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp
@@ -90,10 +90,9 @@ class WGANGPGradientPenalty(nn.Cell):
 class GenLoss(nn.Cell):
     """Define total Generator loss"""

-    def __init__(self, args, encoder, decoder, discriminator):
+    def __init__(self, args, generator, discriminator):
         super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
+        self.generator = generator
         self.discriminator = discriminator

         self.lambda_1 = Tensor(args.lambda_1, mstype.float32)

@@ -108,9 +107,9 @@ class GenLoss(nn.Cell):
     def construct(self, img_a, att_a, att_a_, att_b, att_b_):
         """Get generator loss"""
         # generate
-        zs_a = self.encoder(img_a)
-        img_fake = self.decoder(zs_a, att_b_)
-        img_recon = self.decoder(zs_a, att_a_)
+        zs_a = self.generator(img_a, mode="enc")
+        img_fake = self.generator(zs_a, att_b_, mode="dec")
+        img_recon = self.generator(zs_a, att_a_, mode="dec")

         # discriminate
         d_fake, dc_fake = self.discriminator(img_fake)

@@ -128,10 +127,9 @@ class GenLoss(nn.Cell):
 class DisLoss(nn.Cell):
     """Define total discriminator loss"""

-    def __init__(self, args, encoder, decoder, discriminator):
+    def __init__(self, args, generator, discriminator):
         super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
+        self.generator = generator
         self.discriminator = discriminator

         self.cyc_loss = P.ReduceMean()

@@ -146,8 +144,7 @@ class DisLoss(nn.Cell):
     def construct(self, img_a, att_a, att_a_, att_b, att_b_):
         """Get discriminator loss"""
         # generate
-        z = self.encoder(img_a)
-        img_fake = self.decoder(z, att_b_)
+        img_fake = self.generator(img_a, att_b_, mode="enc-dec")

         # discriminate
         d_real, dc_real = self.discriminator(img_a)
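With the single-generator signature, both loss cells are constructed from one `Gen` instance, as the train.py hunks further down show. A short usage sketch, assuming `gen` and `dis` cells and the batch tensors used in `construct` above (the cells' return structure is unchanged by this diff):

```python
# Sketch only: wiring the loss cells to the merged generator and running one step.
G_loss_cell = GenLoss(args, gen, dis)
D_loss_cell = DisLoss(args, gen, dis)
g_out = G_loss_cell(img_a, att_a, att_a_, att_b, att_b_)  # generator-side losses
d_out = D_loss_cell(img_a, att_a, att_a_, att_b, att_b_)  # discriminator-side losses
```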
@@ -60,16 +60,13 @@ class DistributedSampler:
     def __len__(self):
         return self.num_samples

-def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name):
-    """Restore the trained generator and discriminator"""
+def resume_generator(args, generator, gen_ckpt_name):
+    """Restore the trained generator"""
     print("Loading the trained models from step {}...".format(args.save_interval))
-    encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name)
-    decoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name)
+    generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name)
+    param_generator = load_checkpoint(generator_path, generator)

-    param_encoder = load_checkpoint(encoder_path, encoder)
-    param_decoder = load_checkpoint(decoder_path, decoder)
-
-    return param_encoder, param_decoder
+    return param_generator

 def resume_discriminator(args, discriminator, dis_ckpt_name):
     """Restore the trained discriminator"""
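`resume_model` becomes `resume_generator`, which restores one checkpoint for the merged generator instead of separate encoder and decoder files. Usage mirrors the resume path in the train.py hunk below; the sketch assumes the `gen` and `dis` cells are already built:

```python
# Sketch only: resuming from generator and discriminator checkpoints.
para_gen = resume_generator(args, gen, args.gen_ckpt_name)
para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
load_param_into_net(gen, para_gen)
load_param_into_net(dis, para_dis)
```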
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================sss
+# ============================================================================
 """Entry point for training AttGAN network"""

 import argparse
@@ -32,12 +32,12 @@ from mindspore.context import ParallelMode
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
 from mindspore.train.serialization import load_param_into_net

-from src.attgan import Genc, Gdec, Dis
+from src.attgan import Gen, Dis
 from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
 from src.data import data_loader
 from src.helpers import Progressbar
 from src.loss import GenLoss, DisLoss
-from src.utils import resume_model, resume_discriminator
+from src.utils import resume_generator, resume_discriminator

 attrs_default = [
     'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
@@ -94,8 +94,7 @@ def parse(arg=None):
                         default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
     parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
     parser.add_argument('--resume_model', action='store_true')
-    parser.add_argument('--enc_ckpt_name', type=str, default='')
-    parser.add_argument('--dec_ckpt_name', type=str, default='')
+    parser.add_argument('--gen_ckpt_name', type=str, default='')
     parser.add_argument('--dis_ckpt_name', type=str, default='')

     return parser.parse_args(arg)
@@ -150,32 +149,28 @@ if __name__ == '__main__':
     print('Training images:', train_length)

     # Define network
-    genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train')
-    gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers,
-                args.inject_layers, args.img_size, mode='train')
+    gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm,
+              args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train')
     dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
               args.dis_layers, args.img_size, mode='train')

     # Initialize network
-    init_weights(genc, 'KaimingUniform', math.sqrt(5))
-    init_weights(gdec, 'KaimingUniform', math.sqrt(5))
+    init_weights(gen, 'KaimingUniform', math.sqrt(5))
     init_weights(dis, 'KaimingUniform', math.sqrt(5))

     # Resume from checkpoint
     if args.resume_model:
-        para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
+        para_gen = resume_generator(args, gen, args.gen_ckpt_name)
         para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
-        load_param_into_net(genc, para_genc)
-        load_param_into_net(gdec, para_gdec)
+        load_param_into_net(gen, para_gen)
         load_param_into_net(dis, para_dis)

     # Define network with loss
-    G_loss_cell = GenLoss(args, genc, gdec, dis)
-    D_loss_cell = DisLoss(args, genc, gdec, dis)
+    G_loss_cell = GenLoss(args, gen, dis)
+    D_loss_cell = DisLoss(args, gen, dis)

     # Define Optimizer
-    G_trainable_params = genc.trainable_params() + gdec.trainable_params()
-    optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
+    optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
     optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)

     # Define One Step Train
@@ -193,21 +188,14 @@ if __name__ == '__main__':

     if rank == 0:
         local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
-        ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder')
-        ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder')
+        ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator')
         ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')

-        cb_params_genc = _InternalCallbackParam()
-        cb_params_genc.train_network = genc
-        cb_params_genc.cur_epoch_num = 0
-        genc_run_context = RunContext(cb_params_genc)
-        ckpt_cb_genc.begin(genc_run_context)
-
-        cb_params_gdec = _InternalCallbackParam()
-        cb_params_gdec.train_network = gdec
-        cb_params_gdec.cur_epoch_num = 0
-        gdec_run_context = RunContext(cb_params_gdec)
-        ckpt_cb_gdec.begin(gdec_run_context)
+        cb_params_gen = _InternalCallbackParam()
+        cb_params_gen.train_network = gen
+        cb_params_gen.cur_epoch_num = 0
+        gen_run_context = RunContext(cb_params_gen)
+        ckpt_cb_gen.begin(gen_run_context)

         cb_params_dis = _InternalCallbackParam()
         cb_params_dis.train_network = dis
@@ -243,16 +231,12 @@ if __name__ == '__main__':
                                  gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)

             if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
-                cb_params_genc.cur_epoch_num = epoch + 1
-                cb_params_gdec.cur_epoch_num = epoch + 1
+                cb_params_gen.cur_epoch_num = epoch + 1
                 cb_params_dis.cur_epoch_num = epoch + 1
-                cb_params_genc.cur_step_num = it + 1
-                cb_params_gdec.cur_step_num = it + 1
+                cb_params_gen.cur_step_num = it + 1
                 cb_params_dis.cur_step_num = it + 1
-                cb_params_genc.batch_num = it + 2
-                cb_params_gdec.batch_num = it + 2
+                cb_params_gen.batch_num = it + 2
                 cb_params_dis.batch_num = it + 2
-                ckpt_cb_genc.step_end(genc_run_context)
-                ckpt_cb_gdec.step_end(gdec_run_context)
+                ckpt_cb_gen.step_end(gen_run_context)
                 ckpt_cb_dis.step_end(dis_run_context)
             it += 1