diff --git a/model_zoo/research/cv/AttGAN/README_CN.md b/model_zoo/research/cv/AttGAN/README_CN.md index 61ae56ea89d..52d9df3924a 100644 --- a/model_zoo/research/cv/AttGAN/README_CN.md +++ b/model_zoo/research/cv/AttGAN/README_CN.md @@ -17,6 +17,8 @@ - [评估](#评估) - [推理过程](#推理过程) - [导出MindIR](#导出MindIR) + - [在Ascend310执行推理](#在Ascend310执行推理) + - [结果](#结果) - [模型描述](#模型描述) - [性能](#性能) - [评估性能](#评估性能) @@ -53,7 +55,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 - 硬件(Ascend) - 使用Ascend来搭建硬件环境。 - 框架 - - [MindSpore](https://www.mindspore.cn/install) + - [MindSpore](https://www.mindspore.cn/install/en) - 如需查看详情,请参见如下资源: - [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html) @@ -70,17 +72,17 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 export RANK_SIZE=1 python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt OR - bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba + bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba.txt # 运行分布式训练示例 - bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba + bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt # 运行评估示例 export DEVICE_ID=0 export RANK_SIZE=1 - python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt + python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt OR - bash 
run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom enc_ckpt_name dec_ckpt_name + bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom.txt gen_ckpt_name ``` 对于分布式训练,需要提前创建JSON格式的hccl配置文件。该配置文件的绝对路径作为运行分布式脚本的第一个参数。 @@ -90,7 +92,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 对于评估脚本,需要提前创建存放自定义图片(jpg)的目录以及属性编辑文件,关于属性编辑文件的说明见[脚本及样例代码](#脚本及样例代码)。目录以及属性编辑文件分别对应参数`custom_data`和`custom_attr`。checkpoint文件被训练脚本默认放置在 - `/output/{experiment_name}/checkpoint`目录下,执行脚本时需要将两个检查点文件(Encoder和Decoder)的名称作为参数传入。 + `/output/{experiment_name}/checkpoint`目录下,执行脚本时需要将检查点文件(Generator)的名称作为参数传入。 [注意] 以上路径均应设置为绝对路径 @@ -102,10 +104,12 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 . └─ cv └─ AttGAN + ├── ascend310_infer # 310推理目录 ├── scripts ├──run_distribute_train.sh # 分布式训练的shell脚本 ├──run_single_train.sh # 单卡训练的shell脚本 - ├──run_eval.sh # 推理脚本 + ├──run_eval.sh # 评估脚本 + ├──run_infer_310.sh # 推理脚本 ├─ src ├─ __init__.py # 初始化文件 ├─ block.py # 基础cell @@ -117,6 +121,9 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 ├─ loss.py # loss计算 ├─ eval.py # 测试脚本 ├─ train.py # 训练脚本 + ├─ export.py # MINDIR模型导出脚本 + ├─ preprocess.py # 310推理预处理脚本 + ├─ postprocess.py # 310推理后处理脚本 └─ README_CN.md # AttGAN的文件描述 ``` @@ -141,7 +148,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 - Ascend处理器环境运行 ```bash - bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba + bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba.txt ``` 上述shell脚本将在后台运行分布式训练。该脚本将在脚本目录下生成相应的LOG{RANK_ID}目录,每个进程的输出记录在相应LOG{RANK_ID}目录下的log.txt文件中。checkpoint文件保存在output/experiment_name/rank{RANK_ID}下。 @@ -153,12 +160,12 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 - 在Ascend环境运行时评估自定义数据集 该网络可以用于修改面部属性,用户将希望修改的图片放在自定义的图片目录下,并根据自己期望修改的属性来修改属性编辑文件(文件的具体参数参照CelebA数据集及属性编辑文件)。完成后,需要将自定义图片目录和属性编辑文件作为参数传入测试脚本,分别对应custom_data以及custom_attr。 - 
评估时选择已经生成好的检查点文件,作为参数传入测试脚本,对应参数为`enc_ckpt_name`和`dec_ckpt_name`(分别保存了编码器和解码器的参数) + 评估时选择已经生成好的检查点文件,作为参数传入测试脚本,对应参数为`gen_ckpt_name`(保存了编码器和解码器的参数) ```bash export DEVICE_ID=0 export RANK_SIZE=1 - python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt + python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --gen_ckpt_name generator-119_84999.ckpt ``` 测试脚本执行完成后,用户进入当前目录下的`output/{experiment_name}/custom_img`下查看修改好的图片。 @@ -168,7 +175,7 @@ CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据 ### 导出MindIR ```shell -python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CKPT_NAME] --dec_ckpt_name [DECODER_CKPT_NAME] --file_format [FILE_FORMAT] +python export.py --experiment_name [EXPERIMENT_NAME] --gen_ckpt_name [GENERATOR_CKPT_NAME] --file_format [FILE_FORMAT] ``` `file_format` 必须在 ["AIR", "MINDIR"]中选择。 @@ -176,6 +183,26 @@ python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CK 脚本会在当前目录下生成对应的MINDIR文件。 +### 在Ascend310执行推理 + +在执行推理前,必须通过export脚本导出MINDIR模型。以下命令展示了如何通过命令在Ascend310上对图片进行属性编辑: + +```bash +bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID] +``` + +- `MINDIR_PATH` MINDIR文件的路径 +- `ATTR_FILE_PATH` 属性编辑文件的路径,路径应当为绝对路径 +- `DATA_PATH` 需要进行推理的数据集目录,图像格式应当为jpg +- `NEED_PREPROCESS` 表示属性编辑文件是否需要预处理,可以在y或者n中选择,如果选择y,表示进行预处理(在第一次执行推理时需要对属性编辑文件进行预处理,图片较多的话需要一些时间) +- `DEVICE_ID` 可选,默认值为0. 
+ +[注] 属性编辑文件的格式可以参考celeba数据集中的list_attr_celeba.txt文件,第一行为要推理的图片数目,第二行为要编辑的属性,接下来的是要编辑的图片名称和属性tag。属性编辑文件中的图片数目必须和数据集目录中的图片数相同。 + +### 结果 + +推理结果保存在脚本执行的目录下,属性编辑后的图片保存在`result_Files/`目录下,推理的时间统计结果保存在`time_Result/`目录下。编辑后的图片以`imgName_attrId.jpg`的格式保存,如`182001_1.jpg`表示对名称为182001的第一个属性进行编辑后的结果,是否对该属性进行编辑根据属性编辑文件的内容决定。 + # 模型描述 ## 性能 diff --git a/model_zoo/research/cv/AttGAN/ascend310_infer/CMakeLists.txt b/model_zoo/research/cv/AttGAN/ascend310_infer/CMakeLists.txt new file mode 100644 index 00000000000..ee3c8544734 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/ascend310_infer/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) + +add_executable(main src/main.cc src/utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) diff --git a/model_zoo/research/cv/AttGAN/ascend310_infer/build.sh b/model_zoo/research/cv/AttGAN/ascend310_infer/build.sh new file mode 100644 index 00000000000..285514e19f2 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/ascend310_infer/build.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ -d out ]; then + rm -rf out +fi + +mkdir out +cd out || exit + +if [ -f "Makefile" ]; then + make clean +fi + +cmake .. \ + -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +make diff --git a/model_zoo/research/cv/AttGAN/ascend310_infer/inc/utils.h b/model_zoo/research/cv/AttGAN/ascend310_infer/inc/utils.h new file mode 100644 index 00000000000..6e58afd350d --- /dev/null +++ b/model_zoo/research/cv/AttGAN/ascend310_infer/inc/utils.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ + +#include +#include +#include +#include +#include +#include "include/api/types.h" + +DIR *OpenDir(std::string_view dirName); +void Denorm(std::vector *outputs); +std::string RealPath(std::string_view path); +std::vector GetAllFiles(std::string_view dirName); +std::vector ReadCfgToTensor(const std::string &file, size_t *n_ptr); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector &outputs); +#endif diff --git a/model_zoo/research/cv/AttGAN/ascend310_infer/src/main.cc b/model_zoo/research/cv/AttGAN/ascend310_infer/src/main.cc new file mode 100644 index 00000000000..ef687c1f247 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/ascend310_infer/src/main.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/dataset/vision_ascend.h" +#include "include/dataset/execute.h" +#include "include/dataset/vision.h" +#include "inc/utils.h" + +using mindspore::Context; +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; +using mindspore::dataset::Execute; +using mindspore::dataset::TensorTransform; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::HWC2CHW; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::Decode; +using color_rep_type = std::underlying_type::type; + +DEFINE_string(gen_mindir_path, "", "generator mindir path"); +DEFINE_string(dataset_path, "", "dataset path"); +DEFINE_string(attr_file_path, "", "attribute file path"); +DEFINE_int32(device_id, 0, "device id"); +DEFINE_int32(image_height, 128, "image height"); +DEFINE_int32(image_width, 128, "image width"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_gen_mindir_path).empty()) { + std::cout << "Invalid generator mindir" << std::endl; + return 1; + } + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + ascend310->SetBufferOptimizeMode("off_optimize"); + context->MutableDeviceInfo().push_back(ascend310); + + mindspore::Graph gen_graph; + Serialization::Load(FLAGS_gen_mindir_path, ModelType::kMindIR, &gen_graph); + + Model gen_model; + Status gen_ret = gen_model.Build(GraphCell(gen_graph), context); + + if (gen_ret != kSuccess) { + std::cout << "ERROR: Generator build failed." 
<< std::endl; + return 1; + } + + size_t n_attrs = 13; + auto all_cfg = ReadCfgToTensor(FLAGS_attr_file_path, &n_attrs); + auto all_files = GetAllFiles(FLAGS_dataset_path); + std::map costTime_map; + double startTimeMs; + double endTimeMs; + size_t size = all_files.size(); + + for (size_t i = 0; i < size; ++i) { + struct timeval start = {0}; + struct timeval end = {0}; + std::cout << "Start predict input files:" << all_files[i] << std::endl; + + auto img = std::make_shared(); + std::shared_ptr decode(new Decode()); + std::shared_ptr hwc2chw(new HWC2CHW()); + auto resizeShape = {FLAGS_image_height, FLAGS_image_width}; + std::shared_ptr resize(new Resize(resizeShape)); + std::shared_ptr normalize( + new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5})); + Execute composeDecode({decode, resize, normalize, hwc2chw}); + auto image = ReadFileToTensor(all_files[i]); + composeDecode(image, img.get()); + + for (size_t k = 0; k < n_attrs; ++k) { + gettimeofday(&start, nullptr); + std::vector inputs; + std::vector outputs; + size_t index = i * n_attrs + k; + std::cout << static_cast(img->DataType()) << std::endl; + inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), img->Data().get(), img->DataSize()); + inputs.emplace_back(all_cfg[index].Name(), all_cfg[index].DataType(), all_cfg[index].Shape(), + all_cfg[index].Data().get(), all_cfg[index].DataSize()); + Status gen_model_ret = gen_model.Predict(inputs, &outputs); + if (gen_model_ret != kSuccess) { + std::cout << "Generator inference " << all_files[i] << " failed." 
<< std::endl; + return 1; + } + Denorm(&outputs); + int pos = all_files[i].find('.'); + std::string fileName = all_files[i].substr(0, pos); + WriteResult(fileName + "_" + std::to_string(k) + ".jpg", outputs); + gettimeofday(&end, nullptr); + startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair(startTimeMs, endTimeMs)); + } + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/model_zoo/research/cv/AttGAN/ascend310_infer/src/utils.cc b/model_zoo/research/cv/AttGAN/ascend310_infer/src/utils.cc new file mode 100644 index 00000000000..e7dcccc420d --- /dev/null +++ b/model_zoo/research/cv/AttGAN/ascend310_infer/src/utils.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "inc/utils.h" + +#include +#include +#include +#include +#include + +using mindspore::MSTensor; +using mindspore::DataType; + +std::vector GetAllFiles(std::string_view dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." || filename->d_type != DT_REG) { + continue; + } + res.emplace_back(std::string(dirName) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + for (auto &f : res) { + std::cout << "image file: " << f << std::endl; + } + return res; +} + +int WriteResult(const std::string& imageFile, const std::vector &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr netOutput; + netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + std::cout << fileName << std::endl; + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE * outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), sizeof(char), outputSize, outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + 
+ ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, + {static_cast(size)}, nullptr, size); + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + return buffer; +} + +std::vector split(std::string inputs) { + std::vector line; + std::stringstream stream(inputs); + std::string result; + while ( stream >> result ) { + line.push_back(result); + } + return line; +} + +std::vector ReadCfgToTensor(const std::string &file, size_t *n_ptr) { + std::vector res; + if (file.empty()) { + std::cout << "Pointer file is nullptr." << std::endl; + exit(1); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist." << std::endl; + exit(1); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << " open failed." << std::endl; + exit(1); + } + + std::string n_images; + std::string n_attrs; + getline(ifs, n_images); + getline(ifs, n_attrs); + + auto n_images_ = std::stoi(n_images); + auto n_attrs_ = std::stoi(n_attrs); + *n_ptr = n_attrs_; + std::cout << "Image number is " << n_images << std::endl; + std::cout << "Attribute number is " << n_attrs << std::endl; + + auto all_lines = n_images_ * n_attrs_; + for (auto i = 0; i < all_lines; i++) { + std::string val; + getline(ifs, val); + std::vector val_split = split(val); + void *data = malloc(sizeof(float)*n_attrs_); + float *elements = reinterpret_cast(data); + for (auto j = 0; j < n_attrs_; j++) elements[j] = atof(val_split[j].c_str()); + auto size = sizeof(float) * n_attrs_; + mindspore::MSTensor buffer(file + std::to_string(i), mindspore::DataType::kNumberTypeFloat32, + {static_cast(size)}, nullptr, size); + memcpy(buffer.MutableData(), elements, size); + res.emplace_back(buffer); + } + ifs.close(); + return res; +} + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! 
" << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir; + dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} + +void Denorm(std::vector *outputs) { + for (size_t i = 0; i < outputs->size(); ++i) { + size_t outputSize = (*outputs)[i].DataSize(); + float* netOutput = reinterpret_cast((*outputs)[i].MutableData()); + size_t outputLen = outputSize / sizeof(float); + + for (size_t j = 0; j < outputLen; ++j) { + netOutput[j] = (netOutput[j] + 1) / 2 * 255; + netOutput[j] = (netOutput[j] < 0) ? 0 : netOutput[j]; + netOutput[j] = (netOutput[j] > 255) ? 
255 : netOutput[j]; + } + } +} diff --git a/model_zoo/research/cv/AttGAN/eval.py b/model_zoo/research/cv/AttGAN/eval.py index 6b36f5c7e5b..66a663bf489 100644 --- a/model_zoo/research/cv/AttGAN/eval.py +++ b/model_zoo/research/cv/AttGAN/eval.py @@ -28,12 +28,12 @@ import mindspore.dataset as de from mindspore import context, Tensor, ops from mindspore.train.serialization import load_param_into_net -from src.attgan import Genc, Gdec +from src.attgan import Gen from src.cell import init_weights from src.data import check_attribute_conflict from src.data import get_loader, Custom from src.helpers import Progressbar -from src.utils import resume_model, denorm +from src.utils import resume_generator, denorm device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) @@ -45,8 +45,7 @@ def parse(arg=None): parser.add_argument('--experiment_name', dest='experiment_name', required=True) parser.add_argument('--test_int', dest='test_int', type=float, default=1.0) parser.add_argument('--num_test', dest='num_test', type=int) - parser.add_argument('--enc_ckpt_name', type=str, default='') - parser.add_argument('--dec_ckpt_name', type=str, default='') + parser.add_argument('--gen_ckpt_name', type=str, default='') parser.add_argument('--custom_img', action='store_true') parser.add_argument('--custom_data', type=str, default='../data/custom') parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt') @@ -62,8 +61,7 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f: args = json.load(f, object_hook=lambda d: argparse.Namespace(**d)) args.test_int = args_.test_int args.num_test = args_.num_test -args.enc_ckpt_name = args_.enc_ckpt_name -args.dec_ckpt_name = args_.dec_ckpt_name +args.gen_ckpt_name = args_.gen_ckpt_name args.custom_img = args_.custom_img args.custom_data = args_.custom_data args.custom_attr = args_.custom_attr @@ -101,15 +99,15 
@@ else: print('Testing images:', min(test_len, args.num_test)) # Model loader -genc = Genc(mode='test') -gdec = Gdec(shortcut_layers=args.shortcut_layers, inject_layers=args.inject_layers, mode='test') +gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm, + args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='test') # Initialize network -init_weights(genc, 'KaimingUniform', math.sqrt(5)) -init_weights(gdec, 'KaimingUniform', math.sqrt(5)) -para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name) -load_param_into_net(genc, para_genc) -load_param_into_net(gdec, para_gdec) +init_weights(gen, 'KaimingUniform', math.sqrt(5)) +para_gen = resume_generator(args, gen, args.gen_ckpt_name) +load_param_into_net(gen, para_gen) + +print("Network initializes successfully.") progressbar = Progressbar() it = 0 @@ -134,8 +132,7 @@ for data in test_dataset_iter: att_b_ = (att_b * 2 - 1) * args.thres_int if i > 0: att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int - a_enc = genc(img_a) - samples.append(gdec(a_enc, att_b_)) + samples.append(gen(img_a, att_b_, mode="enc-dec")) cat = ops.Concat(axis=3) samples = cat(samples).asnumpy() result = denorm(samples) diff --git a/model_zoo/research/cv/AttGAN/export.py b/model_zoo/research/cv/AttGAN/export.py index 84bfb0cdc4f..0f95421780d 100644 --- a/model_zoo/research/cv/AttGAN/export.py +++ b/model_zoo/research/cv/AttGAN/export.py @@ -21,14 +21,13 @@ import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import export, load_param_into_net -from src.utils import resume_model -from src.attgan import Genc, Gdec +from src.utils import resume_generator +from src.attgan import Gen parser = argparse.ArgumentParser(description='Attribute Edit') parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--batch_size", 
type=int, default=1, help="batch size") -parser.add_argument('--enc_ckpt_name', type=str, default='') -parser.add_argument('--dec_ckpt_name', type=str, default='') +parser.add_argument('--gen_ckpt_name', type=str, default='') parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format') parser.add_argument('--experiment_name', dest='experiment_name', required=True) @@ -39,27 +38,20 @@ with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f: args = json.load(f, object_hook=lambda d: argparse.Namespace(**d)) args.device_id = args_.device_id args.batch_size = args_.batch_size -args.enc_ckpt_name = args_.enc_ckpt_name -args.dec_ckpt_name = args_.dec_ckpt_name +args.gen_ckpt_name = args_.gen_ckpt_name args.file_format = args_.file_format context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) if __name__ == '__main__': - genc = Genc(mode='test') - gdec = Gdec(mode='test') + gen = Gen(mode="test") - para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name) - load_param_into_net(genc, para_genc) - load_param_into_net(gdec, para_gdec) + para_gen = resume_generator(args, gen, args.gen_ckpt_name) + load_param_into_net(gen, para_gen) - enc_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32)) - dec_array = genc(enc_array) + input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32)) input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32)) - G_enc_file = f"AttGAN_Generator_Encoder" - export(genc, enc_array, file_name=G_enc_file, file_format=args.file_format) - - G_dec_file = f"AttGAN_Generator_Decoder" - export(gdec, *(dec_array, input_label), file_name=G_dec_file, file_format=args.file_format) + Gen_file = f"attgan_mindir" + export(gen, *(input_array, input_label), file_name=Gen_file, file_format=args.file_format) diff --git 
a/model_zoo/research/cv/AttGAN/postprocess.py b/model_zoo/research/cv/AttGAN/postprocess.py new file mode 100644 index 00000000000..d2010c12095 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/postprocess.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""post process for 310 inference""" +import os +import argparse +import numpy as np +from PIL import Image + +def parse(arg=None): + """Define configuration of postprocess""" + parser = argparse.ArgumentParser() + parser.add_argument('--bin_path', type=str, default='./result_Files/') + parser.add_argument('--target_path', type=str, default='./result_Files/') + return parser.parse_args(arg) + +def load_bin_file(bin_file, shape=None, dtype="float32"): + """Load data from bin file""" + data = np.fromfile(bin_file, dtype=dtype) + if shape: + data = np.reshape(data, shape) + return data + +def save_bin_to_image(data, out_name): + """Save bin file to image arrays""" + image = np.transpose(data, (1, 2, 0)) + im = Image.fromarray(np.uint8(image)) + im.save(out_name) + print("Successfully save image in " + out_name) + +def scan_dir(bin_path): + """Scan directory""" + out = os.listdir(bin_path) + return out + +def postprocess(bin_path): + """Post process bin file""" + file_list = scan_dir(bin_path) + for file in file_list: + data = load_bin_file(bin_path + file, shape=(3, 128, 128), 
dtype="float32") + pos = file.find(".") + file_name = file[0:pos] + "." + "jpg" + outfile = os.path.join(args.target_path, file_name) + save_bin_to_image(data, outfile) + +if __name__ == "__main__": + + args = parse() + postprocess(args.bin_path) diff --git a/model_zoo/research/cv/AttGAN/preprocess.py b/model_zoo/research/cv/AttGAN/preprocess.py new file mode 100644 index 00000000000..eb91b27f116 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/preprocess.py @@ -0,0 +1,123 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Pre-process the attribute file for Ascend 310 inference.

The script reads an attribute file (line 1: number of images, line 2:
attribute names, following lines: image name plus one -1/1 label per
attribute), builds the target attribute vectors obtained by flipping each
attribute in turn, and writes them to ``attrs.txt`` in the current working
directory.
"""
import os
from os.path import join

import argparse
import numpy as np

selected_attrs = [
    'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
    'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]

# Attribute groups in which switching one member on switches the others off.
_EXCLUSIVE_GROUPS = (
    ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'),
    ('Straight_Hair', 'Wavy_Hair'),
    ('Mustache', 'No_Beard'),
)


def parse(arg=None):
    """Define configuration of preprocess.

    Args:
        arg: optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--attrs', dest='attrs', default=selected_attrs, nargs='+', help='attributes to learn')
    parser.add_argument('--attrs_path', type=str, default='../data/list_attr_custom.txt')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
    return parser.parse_args(arg)


# Read the real command line only when executed as a script. On import the
# defaults are used, so the importing program's own sys.argv is neither parsed
# nor able to abort the import with an argparse error.
args = parse() if __name__ == "__main__" else parse([])
args.n_attrs = len(args.attrs)


def _conflicting_attributes(att_name):
    """Return the attribute names that must be cleared when `att_name` is active."""
    if att_name in ('Bald', 'Receding_Hairline'):
        return ['Bangs']
    if att_name == 'Bangs':
        return ['Bald', 'Receding_Hairline']
    for group in _EXCLUSIVE_GROUPS:
        if att_name in group:
            return [name for name in group if name != att_name]
    return []


def check_attribute_conflict(att_batch, att_name, att_names):
    """Clear attributes that contradict an active `att_name`.

    Args:
        att_batch: iterable of mutable attribute rows, modified in place.
        att_name: the attribute that has just been edited.
        att_names: attribute names, in the column order of each row.

    Returns:
        The same `att_batch` object.

    Raises:
        ValueError: if `att_name` is not contained in `att_names`.
    """
    att_id = att_names.index(att_name)
    # The columns to clear depend only on the names, so resolve them once
    # instead of once per row. Conflicting names absent from att_names are
    # ignored.
    clear_ids = [att_names.index(name)
                 for name in _conflicting_attributes(att_name) if name in att_names]
    for att in att_batch:
        if att[att_id] != 0:
            for idx in clear_ids:
                att[idx] = 0.0
    return att_batch


def read_cfg_file(attr_path):
    """Read configuration from attribute file.

    Args:
        attr_path: path of the attribute file.

    Returns:
        A tuple ``(new_attr, attr_number)``: an integer array of shape
        ``(attr_number, 1, len(args.attrs))`` holding 0/1 labels, and the
        image count declared on the first line of the file.

    Raises:
        ValueError: if an attribute is missing from the header line or the
            file holds fewer label rows than its first line declares.
    """
    # Read both header lines in one pass and close the file deterministically.
    with open(attr_path, "r", encoding="utf-8") as attr_file:
        attr_number = int(attr_file.readline())
        attr_list = attr_file.readline().split()
    # Column 0 of every data row is the image name, hence the +1 offset.
    # args.attrs (not the default list) is used so the columns always agree
    # with args.n_attrs.
    atts = [attr_list.index(att) + 1 for att in args.attrs]
    # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
    # ndmin=2 keeps a single data row two-dimensional, which removes the
    # need for a special case when only one image is listed.
    labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int, ndmin=2)
    if labels.shape[0] < attr_number:
        raise ValueError("%s declares %d images but contains %d label rows"
                         % (attr_path, attr_number, labels.shape[0]))
    # Map labels from {-1, 1} to {0, 1} and add the per-image batch axis.
    new_attr = ((labels[:attr_number] + 1) // 2)[:, np.newaxis, :]
    return new_attr, attr_number


def preprocess_cfg(attrs, numbers):
    """Preprocess attribute file.

    For every image the unmodified attribute vector is emitted first,
    followed by one vector per attribute in which that attribute is flipped
    and conflicting attributes are cleared.

    Args:
        attrs: indexable of 0/1 attribute arrays, one per image; each is
            indexed as ``[:, i]``, so it must be two-dimensional.
        numbers: number of images to process.

    Returns:
        A flat list with ``numbers * (args.n_attrs + 1)`` float arrays whose
        values are scaled to ``+/-args.thres_int``; the flipped attribute of
        each edited vector is rescaled to ``+/-args.test_int``.
    """
    new_attr = []
    for index in range(numbers):
        attr = attrs[index]
        att_b_list = [attr]
        for i in range(args.n_attrs):
            tmp = attr.copy()
            tmp[:, i] = 1 - tmp[:, i]
            tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
            att_b_list.append(tmp)
        for i, att_b in enumerate(att_b_list):
            att_b_ = (att_b * 2 - 1) * args.thres_int
            if i > 0:
                # Entry i flipped attribute i - 1; give it the test intensity.
                att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
            new_attr.append(att_b_)
    return new_attr


def write_cfg_file(attrs, numbers):
    """Write attribute file.

    Writes ``attrs.txt`` into the current working directory: the image
    count, the attribute count, then one space-separated vector per line.

    Args:
        attrs: list of attribute arrays; row 0 of each array is written.
        numbers: number of images.
    """
    cur_dir = os.getcwd()
    print(cur_dir)
    path = join(cur_dir, 'attrs.txt')
    # NOTE(review): preprocess_cfg yields args.n_attrs + 1 vectors per image,
    # while only the first numbers * args.n_attrs entries are written here.
    # Kept as is; confirm against the reader in ascend310_infer.
    counts = numbers * args.n_attrs
    with open(path, "w", encoding="utf-8") as f:
        f.write(str(numbers))
        f.write("\n")
        f.write(str(args.n_attrs))
        f.write("\n")
        for index in range(counts):
            f.write(" ".join(str(x) for x in attrs[index][0]))
            f.write("\n")
    print("Generate cfg file successfully.")


if __name__ == "__main__":

    # Stop here instead of printing a hint and then failing inside open().
    if args.attrs_path is None or not os.path.isfile(args.attrs_path):
        raise SystemExit("Path is not correct!")
    attributes, n_images = read_cfg_file(args.attrs_path)
    new_attrs = preprocess_cfg(attributes, n_images)
    write_cfg_file(new_attrs, n_images)

# Patch metadata that shares the last physical line with this module, kept verbatim:
# diff --git a/model_zoo/research/cv/AttGAN/scripts/run_eval.sh b/model_zoo/research/cv/AttGAN/scripts/run_eval.sh index e10b7f2c74c..484ea1b1482 100644 ---
a/model_zoo/research/cv/AttGAN/scripts/run_eval.sh +++ b/model_zoo/research/cv/AttGAN/scripts/run_eval.sh @@ -14,17 +14,16 @@ # limitations under the License. # ============================================================================ -if [ $# != 5 ] +if [ $# != 4 ] then - echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [ENC_CKPT_NAME] [DEC_CKPT_NAME]" + echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [GEN_CKPT_NAME]" exit 1 fi experiment_name=$1 data_path=$2 attr_path=$3 -enc_name=$4 -dec_name=$5 +gen_ckpt_name=$4 cores=`cat /proc/cpuinfo|grep "processor" |wc -l` echo "The number of logical core" $cores @@ -47,5 +46,4 @@ python eval.py \ --custom_data $data_path \ --custom_attr $attr_path \ --custom_img \ ---enc_ckpt_name $enc_name \ ---dec_ckpt_name $dec_name > ./scripts/EVAL_LOG/log.txt 2>&1 & +--gen_ckpt_name $gen_ckpt_name > ./scripts/EVAL_LOG/log.txt 2>&1 & diff --git a/model_zoo/research/cv/AttGAN/scripts/run_infer_310.sh b/model_zoo/research/cv/AttGAN/scripts/run_infer_310.sh new file mode 100644 index 00000000000..0968a270bb6 --- /dev/null +++ b/model_zoo/research/cv/AttGAN/scripts/run_infer_310.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [[ $# -lt 4 || $# -gt 5 ]]; then + echo "Usage: bash run_infer_310.sh [GEN_MINDIR_PATH] [ATTR_FILE_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID] + NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'. + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +gen_model=$(get_real_path $1) +attr_path=$(get_real_path $2) +data_path=$(get_real_path $3) + +if [ "$4" == "y" ] || [ "$4" == "n" ];then + need_preprocess=$4 +else + echo "weather need preprocess or not, it's value must be in [y, n]" + exit 1 +fi + +device_id=0 +if [ $# == 5 ]; then + device_id=$5 +fi + +echo "generator mindir name: "$gen_model +echo "attribute file path: "$attr_path +echo "dataset path: "$data_path +echo "need preprocess: "$need_preprocess +echo "device id: "$device_id + +export ASCEND_HOME=/usr/local/Ascend/ +if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then + export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH + export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe + export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp +else + export 
PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH + export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/opp +fi + +function preprocess_data() +{ + echo "Start to preprocess attr file..." + python ../preprocess.py --attrs_path=$attr_path --test_int=1.0 --thres_int=0.5 &> preprocess.log + echo "Attribute file generates successfully!" +} + +function compile_app() +{ + echo "Start to compile source code..." + cd ../ascend310_infer || exit + bash build.sh &> build.log + echo "Compile successfully." +} + +function infer() +{ + cd - || exit + if [ -d result_Files ]; then + rm -rf ./result_Files + fi + if [ -d time_Result ]; then + rm -rf ./time_Result + fi + mkdir result_Files + mkdir time_Result + echo "Start to execute inference..." + ../ascend310_infer/out/main --gen_mindir_path=$gen_model --dataset_path=$data_path --attr_file_path="attrs.txt" --device_id=$device_id --image_height=128 --image_width=128 &> infer.log +} + +function postprocess_data() +{ + echo "Start to postprocess image file..." + python ../postprocess.py --bin_path="./result_Files/" --target_path="./result_Files/" +} + +if [ $need_preprocess == "y" ]; then + preprocess_data + if [ $? -ne 0 ]; then + echo "preprocess attrs failed" + exit 1 + fi +fi + +compile_app +if [ $? -ne 0 ]; then + echo "compile app code failed" + exit 1 +fi + +infer +if [ $? -ne 0 ]; then + echo "execute inference failed" + exit 1 +fi + +postprocess_data +if [ $? 
-ne 0 ]; then + echo "postprocess images failed" + exit 1 +fi diff --git a/model_zoo/research/cv/AttGAN/src/attgan.py b/model_zoo/research/cv/AttGAN/src/attgan.py index 453d64846ec..fc33510fe34 100644 --- a/model_zoo/research/cv/AttGAN/src/attgan.py +++ b/model_zoo/research/cv/AttGAN/src/attgan.py @@ -22,12 +22,15 @@ from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock # Image size 128 x 128 MAX_DIM = 64 * 16 - -class Genc(nn.Cell): - """Generator encoder""" +class Gen(nn.Cell): + """Generator""" def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu", - mode='test'): + dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", + n_attrs=13, shortcut_layers=1, inject_layers=1, img_size=128, mode="test"): super().__init__() + self.shortcut_layers = min(shortcut_layers, dec_layers - 1) + self.inject_layers = min(inject_layers, dec_layers - 1) + self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128 layers = [] n_in = 3 @@ -39,27 +42,7 @@ class Genc(nn.Cell): n_in = n_out self.enc_layers = nn.CellList(layers) - def construct(self, x): - """Encoder construct""" - z = x - zs = [] - for layer in self.enc_layers: - z = layer(z) - zs.append(z) - return zs - - -class Gdec(nn.Cell): - """Generator decoder""" - def __init__(self, dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", n_attrs=13, - shortcut_layers=1, inject_layers=1, img_size=128, mode='test'): - super().__init__() - self.shortcut_layers = min(shortcut_layers, dec_layers - 1) - self.inject_layers = min(inject_layers, dec_layers - 1) - self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128 - layers = [] - n_in = 1024 n_in = n_in + n_attrs # 1024 + 13 for i in range(dec_layers): if i < dec_layers - 1: @@ -80,7 +63,16 @@ class Gdec(nn.Cell): self.repeat = P.Tile() self.cat = P.Concat(1) - def construct(self, zs, a): + def encoder(self, x): + """Encoder construct""" + z = x + zs = [] + for layer in 
self.enc_layers: + z = layer(z) + zs.append(z) + return zs + + def decoder(self, zs, a): """Decoder construct""" a_tile = self.view(a, (a.shape[0], -1, 1, 1)) multiples = (1, 1, self.f_size, self.f_size) @@ -100,6 +92,16 @@ class Gdec(nn.Cell): i = i + 1 return z + def construct(self, x, a=None, mode="enc-dec"): + result = None + if mode == "enc-dec": + out = self.encoder(x) + result = self.decoder(out, a) + if mode == "enc": + result = self.encoder(x) + if mode == "dec": + result = self.decoder(x, a) + return result class Dis(nn.Cell): """Discriminator""" diff --git a/model_zoo/research/cv/AttGAN/src/cell.py b/model_zoo/research/cv/AttGAN/src/cell.py index ec8d9a2928d..5271048c6ea 100644 --- a/model_zoo/research/cv/AttGAN/src/cell.py +++ b/model_zoo/research/cv/AttGAN/src/cell.py @@ -116,8 +116,7 @@ class TrainOneStepCellGen(nn.Cell): grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens) if self.reducer_flag: grads = self.grad_reducer(grads) - self.optimizer(grads) - return loss, gf_loss, gc_loss, gr_loss + return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss class TrainOneStepCellDis(nn.Cell): @@ -153,5 +152,4 @@ class TrainOneStepCellDis(nn.Cell): if self.reducer_flag: grads = self.grad_reducer(grads) - self.optimizer(grads) - return loss, d_real_loss, d_fake_loss, dc_loss, df_gp + return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp diff --git a/model_zoo/research/cv/AttGAN/src/loss.py b/model_zoo/research/cv/AttGAN/src/loss.py index 24cccd9c75d..48c237a310b 100644 --- a/model_zoo/research/cv/AttGAN/src/loss.py +++ b/model_zoo/research/cv/AttGAN/src/loss.py @@ -90,10 +90,9 @@ class WGANGPGradientPenalty(nn.Cell): class GenLoss(nn.Cell): """Define total Generator loss""" - def __init__(self, args, encoder, decoder, discriminator): + def __init__(self, args, generator, discriminator): super().__init__() - self.encoder = encoder - self.decoder = decoder + self.generator = generator 
self.discriminator = discriminator self.lambda_1 = Tensor(args.lambda_1, mstype.float32) @@ -108,9 +107,9 @@ class GenLoss(nn.Cell): def construct(self, img_a, att_a, att_a_, att_b, att_b_): """Get generator loss""" # generate - zs_a = self.encoder(img_a) - img_fake = self.decoder(zs_a, att_b_) - img_recon = self.decoder(zs_a, att_a_) + zs_a = self.generator(img_a, mode="enc") + img_fake = self.generator(zs_a, att_b_, mode="dec") + img_recon = self.generator(zs_a, att_a_, mode="dec") # discriminate d_fake, dc_fake = self.discriminator(img_fake) @@ -128,10 +127,9 @@ class GenLoss(nn.Cell): class DisLoss(nn.Cell): """Define total discriminator loss""" - def __init__(self, args, encoder, decoder, discriminator): + def __init__(self, args, generator, discriminator): super().__init__() - self.encoder = encoder - self.decoder = decoder + self.generator = generator self.discriminator = discriminator self.cyc_loss = P.ReduceMean() @@ -146,8 +144,7 @@ class DisLoss(nn.Cell): def construct(self, img_a, att_a, att_a_, att_b, att_b_): """Get discriminator loss""" # generate - z = self.encoder(img_a) - img_fake = self.decoder(z, att_b_) + img_fake = self.generator(img_a, att_b_, mode="enc-dec") # discriminate d_real, dc_real = self.discriminator(img_a) diff --git a/model_zoo/research/cv/AttGAN/src/utils.py b/model_zoo/research/cv/AttGAN/src/utils.py index df9d6785351..63538507882 100644 --- a/model_zoo/research/cv/AttGAN/src/utils.py +++ b/model_zoo/research/cv/AttGAN/src/utils.py @@ -60,16 +60,13 @@ class DistributedSampler: def __len__(self): return self.num_samples -def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name): - """Restore the trained generator and discriminator""" +def resume_generator(args, generator, gen_ckpt_name): + """Restore the trained generator""" print("Loading the trained models from step {}...".format(args.save_interval)) - encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name) - decoder_path = 
os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name) + generator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', gen_ckpt_name) + param_generator = load_checkpoint(generator_path, generator) - param_encoder = load_checkpoint(encoder_path, encoder) - param_decoder = load_checkpoint(decoder_path, decoder) - - return param_encoder, param_decoder + return param_generator def resume_discriminator(args, discriminator, dis_ckpt_name): """Restore the trained discriminator""" diff --git a/model_zoo/research/cv/AttGAN/train.py b/model_zoo/research/cv/AttGAN/train.py index df0854514ca..1a3dcbb8f42 100644 --- a/model_zoo/research/cv/AttGAN/train.py +++ b/model_zoo/research/cv/AttGAN/train.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================sss +# ============================================================================ """Entry point for training AttGAN network""" import argparse @@ -32,12 +32,12 @@ from mindspore.context import ParallelMode from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext from mindspore.train.serialization import load_param_into_net -from src.attgan import Genc, Gdec, Dis +from src.attgan import Gen, Dis from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights from src.data import data_loader from src.helpers import Progressbar from src.loss import GenLoss, DisLoss -from src.utils import resume_model, resume_discriminator +from src.utils import resume_generator, resume_discriminator attrs_default = [ 'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', @@ -94,8 +94,7 @@ def parse(arg=None): default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.") parser.add_argument('--resume_model', action='store_true') - parser.add_argument('--enc_ckpt_name', type=str, default='') - parser.add_argument('--dec_ckpt_name', type=str, default='') + parser.add_argument('--gen_ckpt_name', type=str, default='') parser.add_argument('--dis_ckpt_name', type=str, default='') return parser.parse_args(arg) @@ -150,32 +149,28 @@ if __name__ == '__main__': print('Training images:', train_length) # Define network - genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train') - gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers, - args.inject_layers, args.img_size, mode='train') + gen = Gen(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, args.dec_dim, args.dec_layers, args.dec_norm, + args.dec_acti, args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size, mode='train') dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti, args.dis_layers, args.img_size, mode='train') # Initialize network - init_weights(genc, 'KaimingUniform', math.sqrt(5)) - init_weights(gdec, 'KaimingUniform', math.sqrt(5)) + init_weights(gen, 'KaimingUniform', math.sqrt(5)) init_weights(dis, 'KaimingUniform', math.sqrt(5)) # Resume from checkpoint if args.resume_model: - para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name) + para_gen = resume_generator(args, gen, args.gen_ckpt_name) para_dis = resume_discriminator(args, dis, args.dis_ckpt_name) - load_param_into_net(genc, para_genc) - load_param_into_net(gdec, para_gdec) + load_param_into_net(gen, para_gen) load_param_into_net(dis, para_dis) # Define network with loss - G_loss_cell = GenLoss(args, genc, gdec, dis) - D_loss_cell = DisLoss(args, genc, gdec, dis) + G_loss_cell = GenLoss(args, gen, dis) + D_loss_cell = 
DisLoss(args, gen, dis) # Define Optimizer - G_trainable_params = genc.trainable_params() + gdec.trainable_params() - optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2) + optimizer_G = nn.Adam(params=gen.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2) optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2) # Define One Step Train @@ -193,21 +188,14 @@ if __name__ == '__main__': if rank == 0: local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank)) - ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder') - ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder') + ckpt_cb_gen = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='generator') ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator') - cb_params_genc = _InternalCallbackParam() - cb_params_genc.train_network = genc - cb_params_genc.cur_epoch_num = 0 - genc_run_context = RunContext(cb_params_genc) - ckpt_cb_genc.begin(genc_run_context) - - cb_params_gdec = _InternalCallbackParam() - cb_params_gdec.train_network = gdec - cb_params_gdec.cur_epoch_num = 0 - gdec_run_context = RunContext(cb_params_gdec) - ckpt_cb_gdec.begin(gdec_run_context) + cb_params_gen = _InternalCallbackParam() + cb_params_gen.train_network = gen + cb_params_gen.cur_epoch_num = 0 + gen_run_context = RunContext(cb_params_gen) + ckpt_cb_gen.begin(gen_run_context) cb_params_dis = _InternalCallbackParam() cb_params_dis.train_network = dis @@ -243,16 +231,12 @@ if __name__ == '__main__': gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp) if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0: - cb_params_genc.cur_epoch_num = epoch + 1 - cb_params_gdec.cur_epoch_num = epoch + 1 + 
cb_params_gen.cur_epoch_num = epoch + 1 cb_params_dis.cur_epoch_num = epoch + 1 - cb_params_genc.cur_step_num = it + 1 - cb_params_gdec.cur_step_num = it + 1 + cb_params_gen.cur_step_num = it + 1 cb_params_dis.cur_step_num = it + 1 - cb_params_genc.batch_num = it + 2 - cb_params_gdec.batch_num = it + 2 + cb_params_gen.batch_num = it + 2 cb_params_dis.batch_num = it + 2 - ckpt_cb_genc.step_end(genc_run_context) - ckpt_cb_gdec.step_end(gdec_run_context) + ckpt_cb_gen.step_end(gen_run_context) ckpt_cb_dis.step_end(dis_run_context) it += 1