!18949 [线上贡献]黄金赛段Single-Path-NAS网络精度性能调优提交PR

Merge pull request !18949 from ZJUTER0126/single-path-nas
This commit is contained in:
i-robot 2021-09-08 02:47:47 +00:00 committed by Gitee
commit 00ac7b855c
25 changed files with 2315 additions and 0 deletions

View File

@ -0,0 +1,244 @@
# 目录
<!-- TOC -->
- [目录](#目录)
- [single-path-nas描述](#single-path-nas描述)
- [数据集](#数据集)
- [特性](#特性)
- [混合精度](#混合精度)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [分布式训练](#分布式训练)
- [评估过程](#评估过程)
- [评估](#评估)
- [导出过程](#导出过程)
- [导出](#导出)
- [推理过程](#推理过程)
- [推理](#推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
- [ImageNet-1k上的single-path-nas](#imagenet-1k上的single-path-nas)
- [推理性能](#推理性能)
- [ImageNet-1k上的single-path-nas](#imagenet-1k上的single-path-nas-1)
- [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# single-path-nas描述
single-path-nas的作者用一个7x7的大卷积来代表3x3、5x5和7x7的三种卷积把外边一圈mask清零掉就变成了3x3或5x5这个大的卷积成为superkernel于是整个网络只有一种卷积看起来是一个直筒结构。搜索空间是基于block的直筒结构跟ProxylessNAS和FBNet一样都采用了Inverted Bottleneck 作为cell, 层数跟MobileNetV2都是22层。每层只有两个参数 expansion rate, kernel size是需要搜索的其他都已固定比如22层中每层的filter number固定死了跟FBNet一样跟MobileNetV2比略有变化。论文中的kernel size和FBNet、 ProxylessNAS一样只有3x3和5x5两种没有用上7x7。论文中的expansion ratio也只有3和6两种选择。kernel size 和 expansion ratio都只有2中选择论文选择用Lightnn这篇论文中的手法把离散选择用连续的光滑函数来表示阈值用group Lasso term。本论文用了跟ProxylessNAS一样的手法来表达skip connection, 用一个zero layer表示。
摘自https://zhuanlan.zhihu.com/p/63605721
# 数据集
使用的数据集:[ImageNet2012](http://www.image-net.org/)
- 数据集大小共1000个类、224*224彩色图像
- 训练集共1,281,167张图像
- 测试集共50,000张图像
- 数据格式JPEG
- 注数据在dataset.py中处理。
- 下载数据集,目录结构如下:
```text
└─dataset
├─train # 训练数据集
└─val # 评估数据集
```
# 特性
## 混合精度
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html) 的训练方法,使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
# 环境要求
- 硬件Ascend
- 使用Ascend来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
# 快速入门
通过官方网站安装MindSpore后您可以按照如下步骤进行训练和评估
- Ascend处理器环境运行
```bash
# 运行训练示例
python train.py --device_id=0 > train.log 2>&1 &
# 运行分布式训练示例
bash ./scripts/run_train.sh [RANK_TABLE_FILE] imagenet
# 运行评估示例
python eval.py --checkpoint_path ./ckpt_0 > ./eval.log 2>&1 &
# 运行推理示例
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
对于分布式训练需要提前创建JSON格式的hccl配置文件。
请遵循以下链接中的说明:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
# 脚本说明
## 脚本及样例代码
```bash
├── model_zoo
├── README_CN.md // Single-Path-NAS相关说明
├── scripts
│ ├──run_train.sh // 分布式到Ascend的shell脚本
│ ├──run_eval.sh // 测试脚本
│ ├──run_infer_310.sh // 310推理脚本
├── src
│ ├──lr_scheduler // 学习率相关文件夹包含学习率变化策略的py文件
│ ├──dataset.py // 创建数据集
│ ├──CrossEntropySmooth.py // 损失函数相关
│ ├──spnasnet.py // Single-Path-NAS网络架构
│ ├──config.py // 参数配置
│ ├──utils.py // spnasnet.py的自定义网络模块
├── train.py // 训练和测试文件
```
## 脚本参数
在config.py中可以同时配置训练参数和评估参数。
- 配置single-path-nas和ImageNet-1k数据集。
```python
'name':'imagenet' # 数据集
'pre_trained':'False' # 是否基于预训练模型训练
'num_classes':1000 # 数据集类数
'lr_init':0.26 # 初始学习率单卡训练时设置为0.26八卡并行训练时设置为1.5
'batch_size':128 # 训练批次大小
'epoch_size':180 # 总计训练epoch数
'momentum':0.9 # 动量
'weight_decay':1e-5 # 权重衰减值
'image_height':224 # 输入到模型的图像高度
'image_width':224 # 输入到模型的图像宽度
'data_path':'/data/ILSVRC2012_train/' # 训练数据集的绝对全路径
'val_data_path':'/data/ILSVRC2012_val/' # 评估数据集的绝对全路径
'device_target':'Ascend' # 运行设备
'device_id':0 # 用于训练或评估数据集的设备ID使用run_train.sh进行分布式训练时可以忽略。
'keep_checkpoint_max':40 # 最多保存80个ckpt模型文件
'checkpoint_path':None # checkpoint文件保存的绝对全路径
```
更多配置细节请参考脚本`config.py`。
## 训练过程
### 训练
- Ascend处理器环境运行
```bash
python train.py --device_id=0 > train.log 2>&1 &
```
上述python命令将在后台运行可以通过生成的train.log文件查看结果。
### 分布式训练
- Ascend处理器环境运行
```bash
bash ./scripts/run_train.sh [RANK_TABLE_FILE] imagenet
```
上述shell脚本将在后台运行分布训练。
## 评估过程
### 评估
- 在Ascend环境运行时评估ImageNet-1k数据集
“./ckpt_0”是保存了训练好的.ckpt模型文件的目录。
```bash
python eval.py --checkpoint_path ./ckpt_0 > ./eval.log 2>&1 &
OR
bash ./scripts/run_eval.sh
```
## 导出过程
### 导出
```shell
python export.py --ckpt_file [CKPT_FILE]
```
## 推理过程
### 推理
在进行推理之前我们需要先导出模型。mindir可以在任意环境上导出air模型只能在昇腾910环境上导出。以下展示了使用mindir模型执行推理的示例。
- 在昇腾310上使用ImageNet-1k数据集进行推理
推理的结果保存在scripts目录下在acc.log日志文件中可以找到类似以下的结果。
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Total data: 50000, top1 accuracy: 0.74214, top5 accuracy: 0.91652.
```
# 模型描述
## 性能
### 评估性能
#### ImageNet-1k上的single-path-nas
| 参数 | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| 模型版本 | single-path-nas |
| 资源 | Ascend 910 |
| 上传日期 | 2021-06-27 |
| MindSpore版本 | 1.2.0 |
| 数据集 | ImageNet-1k Train共1,281,167张图像 |
| 训练参数 | epoch=180, batch_size=128, lr_init=0.26单卡为0.26,八卡为1.5 |
| 优化器 | Momentum |
| 损失函数 | Softmax交叉熵 |
| 输出 | 概率 |
| 分类准确率 | 八卡top1:74.21%,top5:91.712% |
| 速度 | 单卡:毫秒/步八卡87.173毫秒/步 |
### 推理性能
#### ImageNet-1k上的single-path-nas
| 参数 | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| 模型版本 | single-path-nas |
| 资源 | Ascend 310 |
| 上传日期 | 2021-06-27 |
| MindSpore版本 | 1.2.0 |
| 数据集 | ImageNet-1k Val共50,000张图像 |
| 分类准确率 | top1:74.214%,top5:91.652% |
| 速度 | Average time 7.67324 ms of infer_count 50000|
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
std::vector<std::string> GetAllFiles(std::string dir_name);
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
#endif

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(MindSporeCxxTestcase[CXX])
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT}/../)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main main.cc utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cmake . -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,146 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "minddata/dataset/include/vision_ascend.h"
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "inc/utils.h"
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform;
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
auto all_files = GetAllInputData(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = all_files.size();
// Define transform
std::vector<int32_t> crop_paras = {224};
std::vector<int32_t> resize_paras = {256};
std::vector<float> mean = {0.485 * 255, 0.456 * 255, 0.406 * 255};
std::vector<float> std = {0.229 * 255, 0.224 * 255, 0.225 * 255};
auto decode = Decode();
auto resize = Resize(resize_paras);
auto centercrop = CenterCrop(crop_paras);
auto normalize = Normalize(mean, std);
auto hwc2chw = HWC2CHW();
mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
for (size_t i = 0; i < size; ++i) {
for (size_t j = 0; j < all_files[i].size(); ++j) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i][j] <<std::endl;
auto imgDvpp = std::make_shared<MSTensor>();
SingleOp(ReadFileToTensor(all_files[i][j]), imgDvpp.get());
inputs.emplace_back(imgDvpp->Name(), imgDvpp->DataType(), imgDvpp->Shape(),
imgDvpp->Data().get(), imgDvpp->DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i][j] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(all_files[i][j], outputs);
}
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,185 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
std::vector<std::vector<std::string>> ret;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
struct dirent *filename;
/* read all the files in the dir ~ */
std::vector<std::string> sub_dirs;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
// get rid of "." and ".."
if (d_name == "." || d_name == ".." || d_name.empty()) {
continue;
}
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
struct stat s;
lstat(dir_path.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
continue;
}
sub_dirs.emplace_back(dir_path);
}
std::sort(sub_dirs.begin(), sub_dirs.end());
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
[](const std::string &d) { return GetAllFiles(d); });
return ret;
}
std::vector<std::string> GetAllFiles(std::string dir_name) {
struct dirent *filename;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
continue;
}
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
return res;
}
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create_imagenet2012_label"""
import os
import json
import argparse
parser = argparse.ArgumentParser(description="Single-Path-NAS imagenet2012 label")
parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.")
args = parser.parse_args()
def create_label(file_path):
"""
create label
"""
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
dirs = os.listdir(file_path)
file_list = []
for file in dirs:
file_list.append(file)
file_list = sorted(file_list)
total = 0
img_label = {}
for i, file_dir in enumerate(file_list):
files = os.listdir(os.path.join(file_path, file_dir))
for f in files:
img_label[f] = i
total += len(files)
with open("imagenet_label.json", "w+") as label:
json.dump(img_label, label)
print("[INFO] Completed! Total {} data.".format(total))
if __name__ == '__main__':
create_label(args.img_path)

View File

@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Process the test set with the .ckpt model in turn.
"""
import argparse
import os
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common import set_seed
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import src.spnasnet as spnasnet
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
set_seed(1)
parser = argparse.ArgumentParser(description='single-path-nas')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet',],
help='dataset name.')
parser.add_argument('--checkpoint_path', type=str, default='./ckpt_0', help='Checkpoint file path or dir path')
parser.add_argument('--device_id', type=int, default=None, help='device id of Ascend. (Default: None)')
args_opt = parser.parse_args()
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss_ = self.ce(logit, label)
return loss_
if __name__ == '__main__':
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = spnasnet.spnasnet(num_classes=cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else:
raise ValueError("dataset is not support.")
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if device_target == "Ascend":
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)
if os.path.isfile(args_opt.checkpoint_path) and args_opt.checkpoint_path.endswith('.ckpt'):
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
acc = model.eval(dataset)
print(f"model {args_opt.checkpoint_path}'s accuracy is {acc}")
elif os.path.isdir(args_opt.checkpoint_path):
file_list = os.listdir(args_opt.checkpoint_path)
for filename in file_list:
de_path = os.path.join(args_opt.checkpoint_path, filename)
if de_path.endswith('.ckpt'):
param_dict = load_checkpoint(de_path)
load_param_into_net(net, param_dict)
net.set_train(False)
acc = model.eval(dataset)
print(f"model {de_path}'s accuracy is {acc}")
else:
raise ValueError("args_opt.checkpoint_path must be a checkpoint file or dir contains checkpoint(s)")

View File

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx or mindir model#################
python export.py
"""
import argparse
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
import src.spnasnet as spnasnet
from src.config import imagenet_cfg
parser = argparse.ArgumentParser(description='single-path-nas export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="single-path-nas", help="output file name.")
parser.add_argument('--width', type=int, default=224, help='input width')
parser.add_argument('--height', type=int, default=224, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend",], help="device target(default: Ascend)")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
else:
raise ValueError("Unsupported platform.")
if __name__ == '__main__':
net = spnasnet.spnasnet(num_classes=imagenet_cfg.num_classes)
assert args.ckpt_file is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

View File

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import json
import argparse
import numpy as np
from src.config import imagenet_cfg
batch_size = 1
parser = argparse.ArgumentParser(description="Single-Path-NAS inference")
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
parser.add_argument("--label_path", type=str, required=True, help="image file path.")
args = parser.parse_args()
def get_result(result_path, label_path):
"""
get result
"""
files = os.listdir(result_path)
with open(label_path, "r") as label:
labels = json.load(label)
top1 = 0
top5 = 0
total_data = len(files)
for file in files:
img_ids_name = file.split('_0.')[0]
data_path = os.path.join(result_path, img_ids_name + "_0.bin")
result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, imagenet_cfg.num_classes)
for batch in range(batch_size):
predict = np.argsort(-result[batch], axis=-1)
if labels[img_ids_name+".JPEG"] == predict[0]:
top1 += 1
if labels[img_ids_name+".JPEG"] in predict[:5]:
top5 += 1
print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.")
if __name__ == '__main__':
get_result(args.result_path, args.label_path)

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
dataset_type='imagenet'
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_parallel$i ||exit
env > env.log
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 &
cd ..
done

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_eval.sh checkpoint_path_dir/checkpoint_path_file"
exit 1
fi
if [ ! -d $1 ] && [ ! -f $1 ]
then
echo "error: checkpoint_path=$1 is neither a directory nor a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
echo "start evaluation for device $DEVICE_ID"
python eval.py --checkpoint_path=$1 > ./eval.log 2>&1 &

View File

@ -0,0 +1,99 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
device_id=0
if [ $# == 3 ]; then
device_id=$3
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer/src/ || exit
if [ -f "Makefile" ]; then
make clean
fi
sh build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/src/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
}
function cal_acc()
{
python3.7 ../create_imagenet2012_label.py --img_path=$data_path
python3.7 ../postprocess.py --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log &
}
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi

View File

@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 0 ]
then
echo "Usage: sh run_train.sh"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
dataset_type='imagenet'
ulimit -u unlimited
export DEVICE_ID=0
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
python train.py --device_id=$DEVICE_ID --dataset_name=$dataset_type> log 2>&1 &

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

View File

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
imagenet_cfg = edict({
'name': 'imagenet',
'pre_trained': False,
'num_classes': 1000,
'lr_init': 1.5, # 1p:0.26 8p:1.5
'batch_size': 128,
'epoch_size': 180,
'momentum': 0.9,
'weight_decay': 1e-5,
'image_height': 224,
'image_width': 224,
'data_path': '/data/ILSVRC2012_train/',
'val_data_path': '/data/ILSVRC2012_val/',
'device_target': 'Ascend',
'device_id': 0,
'keep_checkpoint_max': 40,
'checkpoint_path': None,
'onnx_filename': 'single-path-nas',
'air_filename': 'single-path-nas',
# optimizer and lr related
'lr_scheduler': 'cosine_annealing',
'lr_epochs': [30, 60, 90],
'lr_gamma': 0.3,
'eta_min': 0.0,
'T_max': 150,
'warmup_epochs': 0,
# loss related
'is_dynamic_loss_scale': 1,
'loss_scale': 1024,
'label_smooth_factor': 0.1,
'use_label_smooth': True,
})

View File

@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from src.config import imagenet_cfg
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=True):
"""
create a train or eval imagenet2012 dataset for resnet50
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
device_num, rank_id = _get_rank_info()
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
num_shards=device_num, shard_id=rank_id)
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
image_size = imagenet_cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if training:
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision.RandomHorizontalFlip(prob=0.5),
vision.RandomColorAdjust(0.5, 0.4, 0.3, 0.2),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = [
vision.Decode(),
vision.Resize(256),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
transform_label = [C.TypeCast(mstype.int32)]
if training:
data_set = data_set.map(input_columns="image", num_parallel_workers=16, operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label)
else:
data_set = data_set.map(input_columns="image", num_parallel_workers=16, operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label)
# apply batch operations
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=False)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id

View File

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
""" warmup cosine annealing lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""warmup step lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma ** milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
"""lr"""
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
"""lr"""
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

View File

@ -0,0 +1,294 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Single-Path NASNet for ImageNet-1K, implemented in Mindspore.
Original paper: 'Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours,'
https://arxiv.org/abs/1904.02877.
"""
import mindspore.nn as nn
import mindspore.ops as ops
from src.utils import conv1x1_block, conv3x3_block, dwconv3x3_block, dwconv5x5_block
class SPNASUnit(nn.Cell):
"""
Single-Path NASNet unit.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int
Strides of the second convolution layer.
use_kernel3 : bool
Whether to use 3x3 (instead of 5x5) kernel.
exp_factor : int
Expansion factor for each unit.
use_skip : bool, default True
Whether to use skip connection.
activation : str, default 'relu'
Activation function or name of activation function.
"""
def __init__(self,
in_channels,
out_channels,
stride,
use_kernel3,
exp_factor,
use_skip=True,
activation="relu"):
super(SPNASUnit, self).__init__()
self.residual = (in_channels == out_channels) and (stride == 1) and use_skip
self.use_exp_conv = exp_factor > 1
mid_channels = exp_factor * in_channels
if self.use_exp_conv:
self.exp_conv = conv1x1_block(
in_channels=in_channels,
out_channels=mid_channels,
activation=activation)
if use_kernel3:
self.conv1 = dwconv3x3_block(
in_channels=mid_channels,
out_channels=mid_channels,
stride=stride,
activation=activation)
else:
self.conv1 = dwconv5x5_block(
in_channels=mid_channels,
out_channels=mid_channels,
stride=stride,
activation=activation)
self.conv2 = conv1x1_block(
in_channels=mid_channels,
out_channels=out_channels,
activation=None)
if self.residual:
self.add = ops.Add()
def construct(self, x):
"""
Args:
x: Tensor of shape :math:`(N, in_channels, W_{in}, H_{in})
Returns:
y: Tensor of shape :math:`(N, out_channels, W_{in}, H_{in})
"""
identity = x
if self.use_exp_conv:
y = self.exp_conv(x)
y = self.conv1(y)
y = self.conv2(y)
else:
y = self.conv1(x)
y = self.conv2(y)
if self.residual:
y = self.add(y, identity)
return y
class SPNASInitBlock(nn.Cell):
"""
Single-Path NASNet specific initial block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
mid_channels : int
Number of middle channels.
"""
def __init__(self,
in_channels,
out_channels,
mid_channels):
super(SPNASInitBlock, self).__init__()
self.conv1 = conv3x3_block(
in_channels=in_channels,
out_channels=mid_channels,
stride=2)
self.conv2 = SPNASUnit(
in_channels=mid_channels,
out_channels=out_channels,
stride=1,
use_kernel3=True,
exp_factor=1,
use_skip=False)
def construct(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class SPNASFinalBlock(nn.Cell):
"""
Single-Path NASNet specific final block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
mid_channels : int
Number of middle channels.
"""
def __init__(self,
in_channels,
out_channels,
mid_channels):
super(SPNASFinalBlock, self).__init__()
self.conv1 = SPNASUnit(
in_channels=in_channels,
out_channels=mid_channels,
stride=1,
use_kernel3=True,
exp_factor=6,
use_skip=False)
self.conv2 = conv1x1_block(
in_channels=mid_channels,
out_channels=out_channels)
def construct(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class SPNASNet(nn.Cell):
"""
Single-Path NASNet model from 'Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours,'
https://arxiv.org/abs/1904.02877.
Parameters:
----------
channels : list of list of int
Number of output channels for each unit.
init_block_channels : list of 2 int
Number of output channels for the initial unit.
final_block_channels : list of 2 int
Number of output channels for the final block of the feature extractor.
kernels3 : list of list of int/bool
Using 3x3 (instead of 5x5) kernel for each unit.
exp_factors : list of list of int
Expansion factor for each unit.
in_channels : int, default 3
Number of input channels.
in_size : tuple of two ints, default (224, 224)
Spatial size of the expected input image.
num_classes : int, default 1000
Number of classification classes.
"""
def __init__(self,
channels,
init_block_channels,
final_block_channels,
kernels3,
exp_factors,
in_channels=3,
in_size=(224, 224),
num_classes=1000):
super(SPNASNet, self).__init__()
self.in_size = in_size
self.num_classes = num_classes
self.features = nn.SequentialCell()
self.features.append(SPNASInitBlock(
in_channels=in_channels,
out_channels=init_block_channels[1],
mid_channels=init_block_channels[0]))
in_channels = init_block_channels[1]
for i, channels_per_stage in enumerate(channels):
stage = nn.SequentialCell()
for j, out_channels in enumerate(channels_per_stage):
stride = 2 if ((j == 0) and (i != 3)) or ((j == len(channels_per_stage) // 2) and (i == 3)) else 1
use_kernel3 = kernels3[i][j] == 1
exp_factor = exp_factors[i][j]
stage.append(SPNASUnit(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
use_kernel3=use_kernel3,
exp_factor=exp_factor))
in_channels = out_channels
self.features.append(stage)
self.features.append(SPNASFinalBlock(
in_channels=in_channels,
out_channels=final_block_channels[1],
mid_channels=final_block_channels[0]))
in_channels = final_block_channels[1]
self.features.append(nn.AvgPool2d(
kernel_size=7,
stride=1))
self.output = nn.Dense(
in_channels=in_channels,
out_channels=num_classes,
weight_init='HeUniform')
self.flatten = nn.Flatten()
def construct(self, x):
x = self.features(x)
x = self.flatten(x)
x = self.output(x)
return x
def get_spnasnet(**kwargs):
"""
Create Single-Path NASNet model with specific parameters.
Parameters:
----------
model_name : str or None, default None
Model name for loading pretrained model.
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mindspore/models'
Location for keeping the model parameters.
"""
init_block_channels = [32, 16]
final_block_channels = [320, 1280]
channels = [[24, 24, 24], [40, 40, 40, 40], [80, 80, 80, 80], [96, 96, 96, 96, 192, 192, 192, 192]]
kernels3 = [[1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0]]
exp_factors = [[3, 3, 3], [6, 3, 3, 3], [6, 3, 3, 3], [6, 3, 3, 3, 6, 6, 6, 6]]
net = SPNASNet(
channels=channels,
init_block_channels=init_block_channels,
final_block_channels=final_block_channels,
kernels3=kernels3,
exp_factors=exp_factors,
**kwargs)
return net
def spnasnet(**kwargs):
"""
Single-Path NASNet model from 'Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours,'
https://arxiv.org/abs/1904.02877.
"""
return get_spnasnet(**kwargs)

View File

@ -0,0 +1,390 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
utils.py
"""
from inspect import isfunction
import mindspore.nn as nn
conv_weight_init = 'HeNormal'
class ConvBlock(nn.Cell):
"""
Standard convolution block with Batch normalization and activation.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
group : int, default 1
Number of group.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
group=1,
has_bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
super(ConvBlock, self).__init__()
self.activate = (activation is not None)
self.use_bn = use_bn
self.use_pad = padding
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode='pad',
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=conv_weight_init)
if self.use_bn:
self.bn = nn.BatchNorm2d(
num_features=out_channels,
momentum=0.9,
eps=bn_eps)
if self.activate:
self.active = get_activation_layer(activation)
def construct(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.activate:
x = self.active(x)
return x
def conv1x1_block(in_channels,
out_channels,
stride=1,
padding=0,
group=1,
has_bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
1x1 version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 0
Padding value for convolution layer.
group : int, default 1
Number of group.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding,
group=group,
has_bias=has_bias,
use_bn=use_bn,
bn_eps=bn_eps,
activation=activation)
def conv3x3_block(in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
group=1,
has_bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
3x3 version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
group : int, default 1
Number of group.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
use_bn=use_bn,
bn_eps=bn_eps,
activation=activation)
def dwconv3x3_block(in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
has_bias=False,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
3x3 depthwise version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
return dwconv_block(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
has_bias=has_bias,
bn_eps=bn_eps,
activation=activation)
def dwconv5x5_block(in_channels,
out_channels,
stride=1,
padding=2,
dilation=1,
has_bias=False,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
5x5 depthwise version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 2
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
return dwconv_block(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=5,
stride=stride,
padding=padding,
dilation=dilation,
has_bias=has_bias,
bn_eps=bn_eps,
activation=activation)
def dwconv_block(in_channels,
out_channels,
kernel_size,
stride=1,
padding=1,
dilation=1,
has_bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
Depthwise version of the standard convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
has_bias : bool, default False
Whether the layer uses a has_bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
return ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
group=out_channels,
has_bias=has_bias,
use_bn=use_bn,
bn_eps=bn_eps,
activation=activation)
class Swish(nn.Cell):
"""
Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941.
"""
def __init__(self):
super(Swish, self).__init__()
self.sigmoid = nn.Sigmoid()
def construct(self, x):
return x * self.sigmoid(x)
class Identity(nn.Cell):
"""
Identity block.
"""
def constructor(self, x):
return x
def get_activation_layer(activation):
"""
Create activation layer from string/function.
Parameters:
----------
activation : function, or str, or nn.Cell
Activation function or name of activation function.
Returns:
-------
nn.Cell
Activation layer.
"""
if isfunction(activation):
active = activation()
elif isinstance(activation, str):
if activation == "relu":
active = nn.ReLU()
elif activation == "relu6":
active = nn.ReLU6()
elif activation == "swish":
active = Swish()
elif activation == "hswish":
active = nn.HSwish()
elif activation == "sigmoid":
active = nn.Sigmoid()
elif activation == "hsigmoid":
active = nn.HSigmoid()
elif activation == "identity":
active = Identity()
else:
raise NotImplementedError()
elif isinstance(activation, nn.Cell):
active = activation
else:
return ValueError()
return active

View File

@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train spnasnet example########################
python train.py
"""
import argparse
import os
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model
from src import spnasnet
from src.CrossEntropySmooth import CrossEntropySmooth
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
set_seed(1)
def lr_steps_imagenet(_cfg, steps_per_epoch):
"""lr step for imagenet"""
from src.lr_scheduler.warmup_step_lr import warmup_step_lr
from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
if _cfg.lr_scheduler == 'exponential':
_lr = warmup_step_lr(_cfg.lr_init,
_cfg.lr_epochs,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
gamma=_cfg.lr_gamma,
)
elif _cfg.lr_scheduler == 'cosine_annealing':
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
_cfg.T_max,
_cfg.eta_min)
else:
raise NotImplementedError(_cfg.lr_scheduler)
return _lr
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Single-Path-NAS Training')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet',],
help='dataset name.')
parser.add_argument('--filter_prefix', type=str, default='huawei', help='filter_prefix name.')
parser.add_argument('--device_id', type=int, default=None, help='device id of Ascend. (Default: None)')
args_opt = parser.parse_args()
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
else:
raise ValueError("Unsupported dataset.")
# set context
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target, enable_graph_kernel=True)
device_num = int(os.environ.get("DEVICE_NUM", 1))
rank = 0
if device_target == "Ascend":
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
rank = get_rank()
else:
raise ValueError("Unsupported platform.")
if args_opt.dataset_name == "imagenet":
dataset = create_dataset_imagenet(cfg.data_path, 1)
else:
raise ValueError("Unsupported dataset.")
batch_num = dataset.get_dataset_size()
net = spnasnet.get_spnasnet(num_classes=cfg.num_classes)
net.update_parameters_name(args_opt.filter_prefix)
loss_scale_manager = None
if args_opt.dataset_name == 'imagenet':
lr = lr_steps_imagenet(cfg, batch_num)
def get_param_groups(network):
""" get param groups """
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
if cfg.is_dynamic_loss_scale:
cfg.loss_scale = 1
opt = Momentum(params=net.get_parameters(),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'top_1_accuracy', 'top_5_accuracy', 'loss'},
amp_level="O3", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 1, keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
ckpoint_cb = ModelCheckpoint(prefix="train_spnasnet_" + args_opt.dataset_name, directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")