This commit is contained in:
ziquan 2021-05-28 10:55:48 +08:00
parent f5a23ddf26
commit 535bf23754
23 changed files with 3389 additions and 0 deletions


@ -0,0 +1,292 @@
# EAST for Ascend
- [EAST Description](#east-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [EAST Description](#contents)
EAST is an efficient and accurate neural network architecture for scene text detection. The pipeline has two stages: a fully convolutional network (FCN) stage and a non-maximum suppression (NMS) merging stage. The FCN directly produces text regions, eliminating redundant and time-consuming intermediate steps. The idea was proposed in the paper "EAST: An Efficient and Accurate Scene Text Detector", published in 2017.
[Paper](https://arxiv.org/pdf/1704.03155.pdf): Xinyu Zhou, Cong Yao, He Wen, Yuzhi Wang, Shuchang Zhou, Weiran He, and Jiajun Liang. Megvii Technology Inc., Beijing, China. Published in CVPR 2017.
# [Model architecture](#contents)
The network structure can be decomposed into three parts: feature extraction, feature merging, and the output layer. The feature-extraction stem uses a backbone such as VGG or ResNet-50 to obtain feature maps. The feature-merging branch borrows the idea of U-Net to combine information from different levels. Finally, the output layer produces the score map and the geometry map.
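A minimal sketch of the resulting output contract (illustrative NumPy only; the 512x512 crop size and the 1/4-scale maps follow the defaults in `src/dataset.py` and `export.py`):
```python
import numpy as np

# For one 512x512 input, EAST emits two maps at 1/4 resolution:
h = w = 512
score_map = np.zeros((1, h // 4, w // 4), np.float32)  # per-pixel text confidence
# geometry channels: distances to the top/bottom/left/right edges of the
# rotated box, plus the rotation angle (see get_score_geo in src/dataset.py)
geo_map = np.zeros((5, h // 4, w // 4), np.float32)
```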
# [Dataset](#contents)
Dataset used [ICDAR 2015](https://rrc.cvc.uab.es/?ch=4&com=downloads)
- Dataset: ICDAR 2015, Challenge 4: Incidental Scene Text
    - Train: 88.5 MB, 1000 images
    - Test: 43.3 MB, 500 images
# [Features](#contents)
## [Mixed Precision](#contents)
Mixed precision training runs most of the network in float16 while keeping numerically sensitive operations in float32, which speeds up computation and reduces memory usage. In this repository, automatic mixed precision is enabled through the MindSpore context.
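A minimal sketch of how it is switched on (this mirrors the `context.set_context` call in `eval.py`):
```python
from mindspore import context

# enable automatic mixed precision on Ascend, as eval.py does
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend",
                    enable_auto_mixed_precision=True)
```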
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```shell
.
└─east
├─README.md
├─scripts
    ├─run_distribute_train.sh        # launch distributed training on Ascend (8p)
    ├─run_standalone_train_ascend.sh # launch standalone training on Ascend (1p)
    └─run_eval_ascend.sh             # launch evaluation on Ascend
  ├─src
    ├─dataset.py                     # data preprocessing
    ├─lr_schedule.py                 # learning rate scheduler
    ├─east.py                        # network definition
    ├─utils.py                       # commonly used utility functions
    ├─distributed_sampler.py         # sampler for distributed training
    ├─initializer.py                 # weight initialization
    └─logger.py                      # logging utilities
  ├─eval.py                          # evaluate the network
  └─train.py                         # train the network
```
## [Training process](#contents)
### Usage
- Ascend:
```bash
# distribute training example(8p)
sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]
# standalone training
sh run_standalone_train_ascend.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]
# evaluation:
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
```
> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection-check time from the default 120 seconds to 600 seconds. Otherwise, the connection may time out, since compilation time increases with model size.
>
> The script binds processor cores according to `device_num` and the total number of processor cores. If you do not want this behavior, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
>
> The `pretrained_path` should be a checkpoint of vgg16 trained on ImageNet2012. The parameter names in the checkpoint dict must match the network's names exactly, and batch normalization must be enabled when training vgg16; otherwise later steps will fail.
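A minimal sketch of that loading step (illustrative; the checkpoint file name below is hypothetical):
```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.east import EAST

net = EAST()
# the parameter names in the checkpoint must match the network's names exactly
param_dict = load_checkpoint("vgg16_bn_imagenet2012.ckpt")  # hypothetical file name
load_param_into_net(net, param_dict)
```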
### Launch
```bash
# training example
shell:
Ascend:
# distribute training example(8p)
sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]
# standalone training
sh run_standalone_train_ascend.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]
```
### Result
Training results will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training logs will be redirected to `./log`.
```python
(8p)
...
epoch: 397 step: 1, loss is 0.2616188
epoch: 397 step: 2, loss is 0.38392675
epoch: 397 step: 3, loss is 0.21342245
epoch: 397 step: 4, loss is 0.29853413
epoch: 397 step: 5, loss is 0.2697169
epoch time: 4432.678 ms, per step time: 886.536 ms
epoch: 398 step: 1, loss is 0.32656515
epoch: 398 step: 2, loss is 0.28596723
epoch: 398 step: 3, loss is 0.24983373
epoch: 398 step: 4, loss is 0.29556546
epoch: 398 step: 5, loss is 0.28608245
epoch time: 5230.462 ms, per step time: 1046.092 ms
epoch: 399 step: 1, loss is 0.24444203
epoch: 399 step: 2, loss is 0.24407807
epoch: 399 step: 3, loss is 0.29774582
epoch: 399 step: 4, loss is 0.2569809
epoch: 399 step: 5, loss is 0.25168353
epoch time: 2595.220 ms, per step time: 519.044 ms
epoch: 400 step: 1, loss is 0.21435773
epoch: 400 step: 2, loss is 0.2563093
epoch: 400 step: 3, loss is 0.23374572
epoch: 400 step: 4, loss is 0.457117
epoch: 400 step: 5, loss is 0.28918257
epoch time: 4661.479 ms, per step time: 932.296 ms
epoch: 401 step: 1, loss is 0.26602226
epoch: 401 step: 2, loss is 0.267757
epoch: 401 step: 3, loss is 0.27752787
epoch: 401 step: 4, loss is 0.28883433
epoch: 401 step: 5, loss is 0.20567583
epoch time: 4297.705 ms, per step time: 859.541 ms
...
(1p)
...
epoch time: 20190.564 ms, per step time: 492.453 ms
epoch: 23 step: 1, loss is 1.4938335
epoch: 23 step: 2, loss is 1.7320133
epoch: 23 step: 3, loss is 1.3432003
epoch: 23 step: 4, loss is 1.375334
epoch: 23 step: 5, loss is 1.2183237
epoch: 23 step: 6, loss is 1.152751
epoch: 23 step: 7, loss is 1.1234403
epoch: 23 step: 8, loss is 1.1597326
epoch: 23 step: 9, loss is 1.390804
epoch: 23 step: 10, loss is 1.2011471
epoch: 23 step: 11, loss is 1.7939932
epoch: 23 step: 12, loss is 1.7997816
epoch: 23 step: 13, loss is 1.4836912
epoch: 23 step: 14, loss is 1.3689598
epoch: 23 step: 15, loss is 1.3506227
epoch: 23 step: 16, loss is 2.132399
epoch: 23 step: 17, loss is 1.4153867
epoch: 23 step: 18, loss is 1.351174
epoch: 23 step: 19, loss is 1.9559281
epoch: 23 step: 20, loss is 1.317142
epoch: 23 step: 21, loss is 1.4965435
epoch: 23 step: 22, loss is 1.2664857
epoch: 23 step: 23, loss is 1.7235017
epoch: 23 step: 24, loss is 1.4537313
epoch: 23 step: 25, loss is 1.7973338
epoch: 23 step: 26, loss is 1.583169
epoch: 23 step: 27, loss is 1.5295832
epoch: 23 step: 28, loss is 2.0665898
epoch: 23 step: 29, loss is 1.3507215
epoch: 23 step: 30, loss is 1.2847648
epoch: 23 step: 31, loss is 1.5181551
epoch: 23 step: 32, loss is 1.4159863
epoch: 23 step: 33, loss is 1.4176369
epoch: 23 step: 34, loss is 1.4142565
epoch: 23 step: 35, loss is 1.3644646
epoch: 23 step: 36, loss is 1.1788905
epoch: 23 step: 37, loss is 1.4377214
epoch: 23 step: 38, loss is 1.108615
epoch: 23 step: 39, loss is 1.2742603
epoch: 23 step: 40, loss is 1.3961313
epoch: 23 step: 41, loss is 1.3044286
...
```
## [Evaluation Process](#contents)
### Usage
You can start evaluation using python or shell scripts. The usage of shell scripts is as follows:
- Ascend:
```bash
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
```
### Launch
- A fast locality-aware NMS implementation in C++ is provided by the paper's author (g++/gcc 6.0+ is fine); you can get it [here](https://github.com/argman/EAST). A sketch of its merging rule appears after the notes below.
- Download the [evaluation tool](https://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1) before evaluating, rename it **evaluate**, and arrange the directories as follows:
```shell
├─lanms                          # lanms tool
├─evaluate
  ├─gt.zip                       # test ground truth
  ├─rrc_evaluation_funcs_1_1.py  # evaluation script from ICDAR 2015
  └─script.py                    # evaluation script from ICDAR 2015
├─eval.py                        # eval net
```
- The evaluation scripts are from [ICDAR Offline evaluation](http://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1) and have been modified to run successfully with Python 3.7.1.
- Change the `evaluate/gt.zip` if you test on other datasets.
- Modify the parameters in `eval.py` and run:
```bash
# eval example
shell:
Ascend:
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
```
> Checkpoints are produced during the training process.
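As noted above, locality-aware NMS first merges geometrically adjacent candidate quads by score-weighted averaging, and only then applies standard NMS. A minimal illustrative sketch of that merging rule (the repository actually calls the compiled C++ module through `lanms.merge_quadrangle_n9`; `quad_iou` and `locality_aware_merge` below are hypothetical helpers):
```python
import numpy as np
from shapely.geometry import Polygon  # already a repo dependency (src/dataset.py)

def quad_iou(g, p):
    """IoU of two quads given as [x1, y1, ..., x4, y4, score]."""
    a, b = Polygon(g[:8].reshape((4, 2))), Polygon(p[:8].reshape((4, 2)))
    if not a.is_valid or not b.is_valid:
        return 0.0
    inter = a.intersection(b).area
    union = a.area + b.area - inter
    return inter / union if union > 0 else 0.0

def weighted_merge(g, p):
    """Merge two quads by score-weighted averaging of their vertices."""
    q = (g[:8] * g[8] + p[:8] * p[8]) / (g[8] + p[8])
    return np.append(q, g[8] + p[8])  # merged quad accumulates both scores

def locality_aware_merge(polys, thresh=0.2):
    """First pass of locality-aware NMS over row-sorted quads <ndarray, (n,9)>;
    a standard NMS pass still follows this step inside lanms."""
    merged, prev = [], None
    for q in polys:
        if prev is not None and quad_iou(q, prev) > thresh:
            prev = weighted_merge(q, prev)
        else:
            if prev is not None:
                merged.append(prev)
            prev = q
    if prev is not None:
        merged.append(prev)
    return np.array(merged)
```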
### Result
Evaluation results will be stored in the example path; you can find results like the following in `log`.
```python
Calculated {"precision": 0.8329088130412634, "recall": 0.7871930669234473, "hmean": 0.8094059405940593, "AP": 0}
```
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Ascend |
| ------------------- | ------------------------------------------------------------ |
| Model Version | EAST |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded Date | 04/27/2021 |
| MindSpore Version | 1.1.1 |
| Dataset | 1000 images |
| Batch_size | 8 |
| Training Parameters | epoch=600, batch_size=8, lr=0.001 |
| Optimizer | Adam |
| Loss Function | Dice for classification, IoU for bbox regression |
| Loss | ~0.27 |
| Total time (8p) | 1h20m |
| Scripts | [east script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/east) |
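The loss row above pairs a dice term on the score map with an IoU-style term on the geometry distances, following the EAST paper (which adds an angle term on top). A minimal NumPy sketch of the two named terms (illustrative only, not the repository's implementation):
```python
import numpy as np

def dice_loss(pred_score, gt_score, eps=1e-5):
    """Dice loss on the score map (classification term)."""
    inter = np.sum(pred_score * gt_score)
    return 1.0 - 2.0 * inter / (np.sum(pred_score) + np.sum(gt_score) + eps)

def iou_loss(d_pred, d_gt, eps=1e-5):
    """-log IoU of the boxes encoded by the 4 distance channels
    (rows: distance to top, bottom, left, right edge)."""
    area_pred = (d_pred[0] + d_pred[1]) * (d_pred[2] + d_pred[3])
    area_gt = (d_gt[0] + d_gt[1]) * (d_gt[2] + d_gt[3])
    h_inter = np.minimum(d_pred[0], d_gt[0]) + np.minimum(d_pred[1], d_gt[1])
    w_inter = np.minimum(d_pred[2], d_gt[2]) + np.minimum(d_pred[3], d_gt[3])
    inter = h_inter * w_inter
    union = area_pred + area_gt - inter
    return np.mean(-np.log((inter + eps) / (union + eps)))
```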
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | EAST |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded Date | 12/27/2021 |
| MindSpore Version | 1.1.1 |
| Dataset | 500 images |
| Batch_size | 1 |
| Accuracy | "precision": 0.8329088130412634, "recall": 0.7871930669234473, "hmean": 0.8094059405940593 |
| Total time | 2 min |
| Model for inference | 172.7 MB (.ckpt file) |
#### Training performance results
| **Ascend** | train performance |
| :--------: | :---------------: |
| 1p         | 51.25 img/s       |
| 8p         | 300 img/s         |
# [Description of Random Situation](#contents)
We set the random seed to 1 in `train.py`.
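A minimal sketch of that setup (assumed to mirror `train.py`):
```python
import numpy as np
from mindspore.common import set_seed

set_seed(1)        # fix MindSpore's global random seed
np.random.seed(1)  # the augmentations in src/dataset.py draw from NumPy (assumption)
```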
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)


@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make


@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);


@ -0,0 +1,118 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision.h"
#include "include/dataset/execute.h"
#include "../inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
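  // The excerpt omits the model-loading step; the lines below sketch it with
  // the standard MindSpore Ascend 310 inference API so that `model` and `ret`
  // used further down are defined.
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph graph;
  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
  Model model;
  Status ret = model.Build(GraphCell(graph), context);
  if (ret != kSuccess) {
    std::cout << "ERROR: Model build failed." << std::endl;
    return 1;
  }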
std::vector<MSTensor> model_inputs = model.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
auto input0_files = GetAllFiles(FLAGS_input0_path);
if (input0_files.empty()) {
std::cout << "ERROR: input data empty." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
auto input0 = ReadFileToTensor(input0_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
input0.Data().get(), input0.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input0_files[i], outputs);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}


@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>  // PATH_MAX, used by RealPath()
#include <cstdlib>  // realpath()
#include "../inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
    FILE *outputFile = fopen(outFileName.c_str(), "wb");
    if (outputFile == nullptr) {
      std::cout << "Failed to open " << outFileName << " for writing" << std::endl;
      continue;
    }
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}


@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from PIL import Image, ImageDraw
import numpy as np
import mindspore.ops as P
from mindspore import Tensor
import mindspore.dataset.vision.py_transforms as V
from src.dataset import get_rotate_mat
import lanms
def resize_img(img):
"""resize image to be divisible by 32
"""
w, h = img.size
resize_w = w
resize_h = h
resize_h = resize_h if resize_h % 32 == 0 else int(resize_h / 32) * 32
resize_w = resize_w if resize_w % 32 == 0 else int(resize_w / 32) * 32
img = img.resize((resize_w, resize_h), Image.BILINEAR)
ratio_h = resize_h / h
ratio_w = resize_w / w
return img, ratio_h, ratio_w
def load_pil(img):
"""convert PIL Image to Tensor
"""
img = V.ToTensor()(img)
img = V.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
img = Tensor(img)
img = P.ExpandDims()(img, 0)
return img
def is_valid_poly(res, score_shape, scale):
"""check if the poly in image scope
Input:
res : restored poly in original image
score_shape: score map shape
scale : feature map -> image
Output:
True if valid
"""
cnt = 0
for i in range(res.shape[1]):
if res[0, i] < 0 or res[0, i] >= score_shape[1] * scale or \
res[1, i] < 0 or res[1, i] >= score_shape[0] * scale:
cnt += 1
return cnt <= 1
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
"""restore polys from feature maps in given positions
Input:
valid_pos : potential text positions <numpy.ndarray, (n,2)>
valid_geo : geometry in valid_pos <numpy.ndarray, (5,n)>
score_shape: shape of score map
scale : image / feature map
Output:
restored polys <numpy.ndarray, (n,8)>, index
"""
polys = []
index = []
valid_pos *= scale
d = valid_geo[:4, :] # 4 x N
angle = valid_geo[4, :] # N,
for i in range(valid_pos.shape[0]):
x = valid_pos[i, 0]
y = valid_pos[i, 1]
y_min = y - d[0, i]
y_max = y + d[1, i]
x_min = x - d[2, i]
x_max = x + d[3, i]
rotate_mat = get_rotate_mat(-angle[i])
temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
        coordinates = np.concatenate((temp_x, temp_y), axis=0)
        res = np.dot(rotate_mat, coordinates)
res[0, :] += x
res[1, :] += y
if is_valid_poly(res, score_shape, scale):
index.append(i)
polys.append([res[0, 0], res[1, 0], res[0, 1], res[1, 1],
res[0, 2], res[1, 2], res[0, 3], res[1, 3]])
return np.array(polys), index
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
"""get boxes from feature map
Input:
score : score map from model <numpy.ndarray, (1,row,col)>
geo : geo map from model <numpy.ndarray, (5,row,col)>
score_thresh: threshold to segment score map
nms_thresh : threshold in nms
Output:
boxes : final polys <numpy.ndarray, (n,9)>
"""
score = score[0, :, :]
xy_text = np.argwhere(score > score_thresh) # n x 2, format is [r, c]
if xy_text.size == 0:
return None
xy_text = xy_text[np.argsort(xy_text[:, 0])]
valid_pos = xy_text[:, ::-1].copy() # n x 2, [x, y]
valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]] # 5 x n
polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape)
if polys_restored.size == 0:
return None
boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = polys_restored
boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]]
boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
return boxes
def adjust_ratio(boxes, ratio_w, ratio_h):
"""refine boxes
Input:
boxes : detected polys <numpy.ndarray, (n,9)>
ratio_w: ratio of width
ratio_h: ratio of height
Output:
refined boxes
"""
if boxes is None or boxes.size == 0:
return None
boxes[:, [0, 2, 4, 6]] /= ratio_w
boxes[:, [1, 3, 5, 7]] /= ratio_h
return np.around(boxes)
def detect(img, model):
    """detect text regions of img using model
    Input:
        img  : PIL Image
        model: detection model
    Output:
        detected polys
    """
img, ratio_h, ratio_w = resize_img(img)
score, geo = model(load_pil(img))
score = P.Squeeze(0)(score)
geo = P.Squeeze(0)(geo)
boxes = get_boxes(score.asnumpy(), geo.asnumpy())
return adjust_ratio(boxes, ratio_w, ratio_h)
def plot_boxes(img, boxes):
"""plot boxes on image
"""
if boxes is None:
return img
draw = ImageDraw.Draw(img)
for box in boxes:
draw.polygon([box[0], box[1], box[2], box[3], box[4],
box[5], box[6], box[7]], outline=(0, 255, 0))
return img
def detect_dataset(model, test_img_path, submit_path):
    """detection on the whole dataset, save .txt results in submit_path
    Input:
        model        : detection model
        test_img_path: dataset path
        submit_path  : directory collecting results submitted for evaluation
    """
img_files = os.listdir(test_img_path)
img_files = sorted([os.path.join(test_img_path, img_file)
for img_file in img_files])
for i, img_file in enumerate(img_files):
print('evaluating {} image'.format(i), end='\r')
boxes = detect(Image.open(img_file), model)
seq = []
if boxes is not None:
seq.extend([','.join([str(int(b))
for b in box[:-1]]) + '\n' for box in boxes])
with open(os.path.join(submit_path, 'res_' +
os.path.basename(img_file).replace('.jpg', '.txt')), 'w') as f:
f.writelines(seq)


@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import shutil
import subprocess
import time
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from detect import detect_dataset
from src.east import EAST
parser = argparse.ArgumentParser('mindspore icdar eval')
# device related
parser.add_argument(
'--device_target',
type=str,
default='Ascend',
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument(
    '--device_num',
    type=int,
    default=5,
    help='ID of the Ascend device the code will run on. (Default: 5)')
parser.add_argument(
'--test_img_path',
default='/data/icdar2015/Test/image/',
type=str,
help='Train dataset directory.')
parser.add_argument('--checkpoint_path', default='best.ckpt', type=str,
                    help='The EAST ckpt file to evaluate. Default: "best.ckpt".')
args, _ = parser.parse_known_args()
context.set_context(
mode=context.GRAPH_MODE,
enable_auto_mixed_precision=True,
device_target=args.device_target,
save_graphs=False,
device_id=args.device_num)
def eval_model(name, img_path, submit, save_flag=True):
if os.path.exists(submit):
shutil.rmtree(submit)
os.mkdir(submit)
network = EAST()
param_dict = load_checkpoint(name)
load_param_into_net(network, param_dict)
network.set_train(True)
start_time = time.time()
detect_dataset(network, img_path, submit)
os.chdir(submit)
res = subprocess.getoutput('zip -q submit.zip *.txt')
res = subprocess.getoutput('mv submit.zip ../')
os.chdir('../')
res = subprocess.getoutput(
'python ./evaluate/script.py -g=./evaluate/gt.zip -s=./submit.zip')
print(res)
os.remove('./submit.zip')
print('eval time is {}'.format(time.time() - start_time))
if not save_flag:
shutil.rmtree(submit)
if __name__ == '__main__':
model_name = args.checkpoint_path
test_img_path = args.test_img_path
submit_path = './submit'
eval_model(model_name, test_img_path, submit_path)


@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.east import EAST
parser = argparse.ArgumentParser(description='EAST')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--image_height', type=int, default=512,
help='image_height.')
parser.add_argument('--image_width', type=int, default=512,
help='image_width.')
parser.add_argument(
'--device_target',
type=str,
default="Ascend",
choices=[
'Ascend',
'GPU',
'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument(
"--ckpt_file",
type=str,
required=True,
help="Checkpoint file path.")
parser.add_argument(
    "--file_name",
    type=str,
    default="east",
    help="output file name.")
parser.add_argument(
"--file_format",
type=str,
choices=[
"AIR",
"ONNX",
"MINDIR"],
default="AIR",
help="file format")
args_opt = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
device_target=args_opt.device_target)
if args_opt.device_target == "Ascend":
context.set_context(device_id=args_opt.device_id)
if __name__ == '__main__':
net = EAST()
param_dict = load_checkpoint(args_opt.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros(
[args_opt.batch_size, 3, args_opt.image_height, args_opt.image_width]), ms.float32)
export(
net,
input_arr,
file_name=args_opt.file_name,
file_format=args_opt.file_format)


@ -0,0 +1,221 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import time
import argparse
import shutil
import math
import subprocess
from PIL import ImageDraw
import numpy as np
import mindspore.ops as P
import lanms
parser = argparse.ArgumentParser(description="east inference")
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
args = parser.parse_args()
def get_rotate_mat(theta):
'''positive theta value means rotate clockwise'''
return np.array([[math.cos(theta), -math.sin(theta)],
[math.sin(theta), math.cos(theta)]])
def is_valid_poly(res, score_shape, scale):
"""check if the poly in image scope
Input:
res : restored poly in original image
score_shape: score map shape
scale : feature map -> image
Output:
True if valid
"""
cnt = 0
for i in range(res.shape[1]):
if res[0, i] < 0 or res[0, i] >= score_shape[1] * scale or \
res[1, i] < 0 or res[1, i] >= score_shape[0] * scale:
cnt += 1
return cnt <= 1
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
"""restore polys from feature maps in given positions
Input:
valid_pos : potential text positions <numpy.ndarray, (n,2)>
valid_geo : geometry in valid_pos <numpy.ndarray, (5,n)>
score_shape: shape of score map
scale : image / feature map
Output:
restored polys <numpy.ndarray, (n,8)>, index
"""
polys = []
index = []
valid_pos *= scale
d = valid_geo[:4, :] # 4 x N
angle = valid_geo[4, :] # N,
for i in range(valid_pos.shape[0]):
x = valid_pos[i, 0]
y = valid_pos[i, 1]
y_min = y - d[0, i]
y_max = y + d[1, i]
x_min = x - d[2, i]
x_max = x + d[3, i]
rotate_mat = get_rotate_mat(-angle[i])
temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
        coordinates = np.concatenate((temp_x, temp_y), axis=0)
        res = np.dot(rotate_mat, coordinates)
res[0, :] += x
res[1, :] += y
if is_valid_poly(res, score_shape, scale):
index.append(i)
polys.append([res[0, 0], res[1, 0], res[0, 1], res[1, 1],
res[0, 2], res[1, 2], res[0, 3], res[1, 3]])
return np.array(polys), index
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
"""get boxes from feature map
Input:
score : score map from model <numpy.ndarray, (1,row,col)>
geo : geo map from model <numpy.ndarray, (5,row,col)>
score_thresh: threshold to segment score map
nms_thresh : threshold in nms
Output:
boxes : final polys <numpy.ndarray, (n,9)>
"""
score = score[0, :, :]
xy_text = np.argwhere(score > score_thresh) # n x 2, format is [r, c]
if xy_text.size == 0:
return None
xy_text = xy_text[np.argsort(xy_text[:, 0])]
valid_pos = xy_text[:, ::-1].copy() # n x 2, [x, y]
valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]] # 5 x n
polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape)
if polys_restored.size == 0:
return None
boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = polys_restored
boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]]
boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
return boxes
def adjust_ratio(boxes, ratio_w, ratio_h):
"""refine boxes
Input:
boxes : detected polys <numpy.ndarray, (n,9)>
ratio_w: ratio of width
ratio_h: ratio of height
Output:
refined boxes
"""
if boxes is None or boxes.size == 0:
return None
boxes[:, [0, 2, 4, 6]] /= ratio_w
boxes[:, [1, 3, 5, 7]] /= ratio_h
return np.around(boxes)
def detect(img, model):
    """detect text regions of img using model
    (unused in this script; mirrors detect.py and relies on its
    resize_img/load_pil helpers)
    Input:
        img  : PIL Image
        model: detection model
    Output:
        detected polys
    """
img, ratio_h, ratio_w = resize_img(img)
score, geo = model(load_pil(img))
score = P.Squeeze(0)(score)
geo = P.Squeeze(0)(geo)
boxes = get_boxes(score.asnumpy(), geo.asnumpy())
return adjust_ratio(boxes, ratio_w, ratio_h)
def plot_boxes(img, boxes):
"""plot boxes on image
"""
if boxes is None:
return img
draw = ImageDraw.Draw(img)
for box in boxes:
draw.polygon([box[0], box[1], box[2], box[3], box[4],
box[5], box[6], box[7]], outline=(0, 255, 0))
return img
def detect_dataset(result_path, submit_path):
    """detection on the whole dataset, save .txt results in submit_path
    Input:
        result_path: directory holding the .bin outputs of Ascend 310 inference
        submit_path: directory collecting results submitted for evaluation
    """
img_files = os.listdir(result_path)
img_files = sorted([os.path.join(result_path, img_file)
for img_file in img_files])
n = len(img_files)
for i in range(0, n, 2):
        print('evaluating {} image'.format(i // 2), end='\r')
        score = np.fromfile(img_files[i], dtype=np.float32).reshape(1, 176, 320)
        geo = np.fromfile(img_files[i + 1], dtype=np.float32).reshape(5, 176, 320)
        boxes = get_boxes(score, geo)
        # ICDAR2015 images are 1280x720; the exported model works on 1280x704
        # (320x176 maps at 1/4 scale), so only the height is rescaled: 704/720
        boxes = adjust_ratio(boxes, 1, 0.97777777777)
seq = []
if boxes is not None:
seq.extend([','.join([str(int(b))
for b in box[:-1]]) + '\n' for box in boxes])
with open(os.path.join(submit_path, 'res_' +
os.path.basename(img_files[i]).replace('_0.bin', '.txt')), 'w') as f:
f.writelines(seq)
def eval_model(img_path, submit, save_flag=True):
if os.path.exists(submit):
shutil.rmtree(submit)
os.mkdir(submit)
start_time = time.time()
detect_dataset(img_path, submit)
os.chdir(submit)
res = subprocess.getoutput('zip -q submit.zip *.txt')
res = subprocess.getoutput('mv submit.zip ../')
os.chdir('../')
res = subprocess.getoutput(
'python ./evaluate/script.py -g=./evaluate/gt.zip -s=./submit.zip')
print(res)
os.remove('./submit.zip')
print('eval time is {}'.format(time.time() - start_time))
if not save_flag:
shutil.rmtree(submit)
if __name__ == '__main__':
eval_model(args.result_path, './submit')


@ -0,0 +1,81 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
PRETRAINED_BACKBONE=$(get_real_path $2)
RANK_TABLE_FILE=$(get_real_path $3)
echo $DATASET_PATH
echo $PRETRAINED_BACKBONE
echo $RANK_TABLE_FILE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
if [ ! -f $RANK_TABLE_FILE ]
then
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_FILE
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=1 \
--lr=0.001 \
--max_epoch=600 \
--per_batch_size=8 \
--lr_scheduler=my_lr > log.txt 2>&1 &
cd ..
done


@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CKPT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CKPT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $CKPT_PATH ]
then
echo "error: CKPT_PATH=$CKPT_PATH is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_ID=$3
export RANK_SIZE=1
rm -rf ./eval_standalone
mkdir ./eval_standalone
cp ../*.py ./eval_standalone
cp -r ../src ./eval_standalone
cp -r ../evaluate ./eval_standalone
cp -r ../lanms ./eval_standalone
cd ./eval_standalone || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python eval.py \
--test_img_path=$DATASET_PATH \
--checkpoint_path=$CKPT_PATH \
--device_num=$DEVICE_ID > log.txt 2>&1 &
cd ..


@ -0,0 +1,100 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
device_id=0
if [ $# == 3 ]; then
device_id=$3
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer/ || exit
if [ -f "Makefile" ]; then
make clean
fi
sh build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --mindir_path=$model --input0_path=$data_path --device_id=$device_id &> infer.log
}
function cal_acc()
{
cd .. || exit
python3.7 postprocess.py --result_path=./scripts/result_Files &> ./acc.log &
}
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi


@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
PRETRAINED_BACKBONE=$(get_real_path $2)
DEVICE_ID=$3
echo $DATASET_PATH
echo $PRETRAINED_BACKBONE
echo $DEVICE_ID
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_ID=$3
export RANK_SIZE=1
rm -rf ./train_standalone
mkdir ./train_standalone
cp ../*.py ./train_standalone
cp -r ../src ./train_standalone
cd ./train_standalone || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--device_id=$DEVICE_ID \
--is_distributed=0 \
--lr=0.001 \
--max_epoch=600 \
--per_batch_size=24 \
--lr_scheduler=my_lr > log.txt 2>&1 &
cd ..


@ -0,0 +1,476 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import os
from shapely.geometry import Polygon
import numpy as np
import cv2
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
def cal_distance(x1, y1, x2, y2):
'''calculate the Euclidean distance'''
return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
def move_points(vertices, index1, index2, r, coef):
'''move the two points to shrink edge
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
index1 : offset of point1
index2 : offset of point2
r : [r1, r2, r3, r4] in paper
coef : shrink ratio in paper
Output:
        vertices: vertices where one edge has been shrunk
'''
index1 = index1 % 4
index2 = index2 % 4
x1_index = index1 * 2 + 0
y1_index = index1 * 2 + 1
x2_index = index2 * 2 + 0
y2_index = index2 * 2 + 1
r1 = r[index1]
r2 = r[index2]
length_x = vertices[x1_index] - vertices[x2_index]
length_y = vertices[y1_index] - vertices[y2_index]
length = cal_distance(
vertices[x1_index],
vertices[y1_index],
vertices[x2_index],
vertices[y2_index])
if length > 1:
ratio = (r1 * coef) / length
vertices[x1_index] += ratio * (-length_x)
vertices[y1_index] += ratio * (-length_y)
ratio = (r2 * coef) / length
vertices[x2_index] += ratio * length_x
vertices[y2_index] += ratio * length_y
return vertices
def shrink_poly(vertices, coef=0.3):
'''shrink the text region
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
coef : shrink ratio in paper
Output:
v : vertices of shrunk text region <numpy.ndarray, (8,)>
'''
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
r1 = min(cal_distance(x1, y1, x2, y2), cal_distance(x1, y1, x4, y4))
r2 = min(cal_distance(x2, y2, x1, y1), cal_distance(x2, y2, x3, y3))
r3 = min(cal_distance(x3, y3, x2, y2), cal_distance(x3, y3, x4, y4))
r4 = min(cal_distance(x4, y4, x1, y1), cal_distance(x4, y4, x3, y3))
r = [r1, r2, r3, r4]
# obtain offset to perform move_points() automatically
if cal_distance(x1, y1, x2, y2) + cal_distance(x3, y3, x4, y4) > \
cal_distance(x2, y2, x3, y3) + cal_distance(x1, y1, x4, y4):
offset = 0 # two longer edges are (x1y1-x2y2) & (x3y3-x4y4)
else:
offset = 1 # two longer edges are (x2y2-x3y3) & (x4y4-x1y1)
v = vertices.copy()
v = move_points(v, 0 + offset, 1 + offset, r, coef)
v = move_points(v, 2 + offset, 3 + offset, r, coef)
v = move_points(v, 1 + offset, 2 + offset, r, coef)
v = move_points(v, 3 + offset, 4 + offset, r, coef)
return v
def get_rotate_mat(theta):
'''positive theta value means rotate clockwise'''
return np.array([[math.cos(theta), -math.sin(theta)],
[math.sin(theta), math.cos(theta)]])
def rotate_vertices(vertices, theta, anchor=None):
'''rotate vertices around anchor
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
theta : angle in radian measure
anchor : fixed position during rotation
Output:
rotated vertices <numpy.ndarray, (8,)>
'''
v = vertices.reshape((4, 2)).T
if anchor is None:
anchor = v[:, :1]
rotate_mat = get_rotate_mat(theta)
res = np.dot(rotate_mat, v - anchor)
return (res + anchor).T.reshape(-1)
def get_boundary(vertices):
'''get the tight boundary around given vertices
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
Output:
the boundary
'''
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
x_min = min(x1, x2, x3, x4)
x_max = max(x1, x2, x3, x4)
y_min = min(y1, y2, y3, y4)
y_max = max(y1, y2, y3, y4)
return x_min, x_max, y_min, y_max
def cal_error(vertices):
'''default orientation is x1y1 : left-top, x2y2 : right-top, x3y3 : right-bot, x4y4 : left-bot
calculate the difference between the vertices orientation and default orientation
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
Output:
err : difference measure
'''
x_min, x_max, y_min, y_max = get_boundary(vertices)
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
err = cal_distance(x1, y1, x_min, y_min) + \
cal_distance(x2, y2, x_max, y_min) + \
cal_distance(x3, y3, x_max, y_max) + \
cal_distance(x4, y4, x_min, y_max)
return err
def find_min_rect_angle(vertices):
'''find the best angle to rotate poly and obtain min rectangle
Input:
vertices: vertices of text region <numpy.ndarray, (8,)>
Output:
the best angle <radian measure>
'''
angle_interval = 1
angle_list = list(range(-90, 90, angle_interval))
area_list = []
for theta in angle_list:
rotated = rotate_vertices(vertices, theta / 180 * math.pi)
x1, y1, x2, y2, x3, y3, x4, y4 = rotated
temp_area = (max(x1, x2, x3, x4) - min(x1, x2, x3, x4)) * \
(max(y1, y2, y3, y4) - min(y1, y2, y3, y4))
area_list.append(temp_area)
sorted_area_index = sorted(
list(
range(
len(area_list))),
key=lambda k: area_list[k])
min_error = float('inf')
best_index = -1
rank_num = 10
# find the best angle with correct orientation
for index in sorted_area_index[:rank_num]:
rotated = rotate_vertices(vertices, angle_list[index] / 180 * math.pi)
temp_error = cal_error(rotated)
if temp_error < min_error:
min_error = temp_error
best_index = index
return angle_list[best_index] / 180 * math.pi
def is_cross_text(start_loc, length, vertices):
'''check if the crop image crosses text regions
Input:
start_loc: left-top position
length : length of crop image
vertices : vertices of text regions <numpy.ndarray, (n,8)>
Output:
True if crop image crosses text region
'''
if vertices.size == 0:
return False
start_w, start_h = start_loc
a = np.array([start_w, start_h, start_w +
length, start_h, start_w +
length, start_h +
length, start_w, start_h +
length]).reshape((4, 2))
p1 = Polygon(a).convex_hull
for vertice in vertices:
p2 = Polygon(vertice.reshape((4, 2))).convex_hull
inter = p1.intersection(p2).area
if 0.01 <= inter / p2.area <= 0.99:
return True
return False
def crop_img(img, vertices, labels, length):
'''crop img patches to obtain batch and augment
Input:
img : PIL Image
vertices : vertices of text regions <numpy.ndarray, (n,8)>
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
length : length of cropped image region
Output:
region : cropped image region
new_vertices: new vertices in cropped region
'''
h, w = img.height, img.width
# confirm the shortest side of image >= length
if h >= w and w < length:
img = img.resize((length, int(h * length / w)), Image.BILINEAR)
elif h < w and h < length:
img = img.resize((int(w * length / h), length), Image.BILINEAR)
ratio_w = img.width / w
ratio_h = img.height / h
assert (ratio_w >= 1 and ratio_h >= 1)
new_vertices = np.zeros(vertices.shape)
if vertices.size > 0:
new_vertices[:, [0, 2, 4, 6]] = vertices[:, [0, 2, 4, 6]] * ratio_w
new_vertices[:, [1, 3, 5, 7]] = vertices[:, [1, 3, 5, 7]] * ratio_h
# find random position
remain_h = img.height - length
remain_w = img.width - length
flag = True
cnt = 0
while flag and cnt < 1000:
cnt += 1
start_w = int(np.random.rand() * remain_w)
start_h = int(np.random.rand() * remain_h)
flag = is_cross_text([start_w, start_h], length,
new_vertices[labels == 1, :])
box = (start_w, start_h, start_w + length, start_h + length)
region = img.crop(box)
if new_vertices.size == 0:
return region, new_vertices
new_vertices[:, [0, 2, 4, 6]] -= start_w
new_vertices[:, [1, 3, 5, 7]] -= start_h
return region, new_vertices
def rotate_all_pixels(rotate_mat, anchor_x, anchor_y, length):
'''get rotated locations of all pixels for next stages
Input:
rotate_mat: rotatation matrix
anchor_x : fixed x position
anchor_y : fixed y position
length : length of image
Output:
rotated_x : rotated x positions <numpy.ndarray, (length,length)>
rotated_y : rotated y positions <numpy.ndarray, (length,length)>
'''
x = np.arange(length)
y = np.arange(length)
x, y = np.meshgrid(x, y)
x_lin = x.reshape((1, x.size))
y_lin = y.reshape((1, x.size))
coord_mat = np.concatenate((x_lin, y_lin), 0)
rotated_coord = np.matmul(rotate_mat.astype(np.float16),
(coord_mat - np.array([[anchor_x],
[anchor_y]])).astype(np.float16)) + np.array([[anchor_x],
[anchor_y]])
rotated_x = rotated_coord[0, :].reshape(x.shape)
rotated_y = rotated_coord[1, :].reshape(y.shape)
return rotated_x, rotated_y
def adjust_height(img, vertices, ratio=0.2):
'''adjust height of image to aug data
Input:
img : PIL Image
vertices : vertices of text regions <numpy.ndarray, (n,8)>
ratio : height changes in [0.8, 1.2]
Output:
img : adjusted PIL Image
new_vertices: adjusted vertices
'''
ratio_h = 1 + ratio * (np.random.rand() * 2 - 1)
old_h = img.height
new_h = int(np.around(old_h * ratio_h))
img = img.resize((img.width, new_h), Image.BILINEAR)
new_vertices = vertices.copy()
if vertices.size > 0:
new_vertices[:, [1, 3, 5, 7]] = vertices[:, [1, 3, 5, 7]] * (new_h / old_h)
return img, new_vertices
def rotate_img(img, vertices, angle_range=10):
'''rotate image [-10, 10] degree to aug data
Input:
img : PIL Image
vertices : vertices of text regions <numpy.ndarray, (n,8)>
angle_range : rotate range
Output:
img : rotated PIL Image
new_vertices: rotated vertices
'''
center_x = (img.width - 1) / 2
center_y = (img.height - 1) / 2
angle = angle_range * (np.random.rand() * 2 - 1)
img = img.rotate(angle, Image.BILINEAR)
new_vertices = np.zeros(vertices.shape)
for i, vertice in enumerate(vertices):
new_vertices[i, :] = rotate_vertices(
vertice, -angle / 180 * math.pi, np.array([[center_x], [center_y]]))
return img, new_vertices
def get_score_geo(img, vertices, labels, scale, length):
'''generate score gt and geometry gt
Input:
img : PIL Image
vertices: vertices of text regions <numpy.ndarray, (n,8)>
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
scale : feature map / image
length : image length
Output:
score gt, geo gt, ignored
'''
score_map = np.zeros(
(int(img.height * scale), int(img.width * scale), 1), np.float32)
geo_map = np.zeros(
(int(img.height * scale), int(img.width * scale), 5), np.float32)
ignored_map = np.zeros(
(int(img.height * scale), int(img.width * scale), 1), np.float32)
index = np.arange(0, length, int(1 / scale))
index_x, index_y = np.meshgrid(index, index)
ignored_polys = []
polys = []
for i, vertice in enumerate(vertices):
if labels[i] == 0:
ignored_polys.append(np.around(scale * vertice.reshape((4, 2))).astype(np.int32))
continue
poly = np.around(scale * shrink_poly(vertice).reshape((4, 2))).astype(np.int32)
polys.append(poly)
temp_mask = np.zeros(score_map.shape[:-1], np.float32)
cv2.fillPoly(temp_mask, [poly], 1)
theta = find_min_rect_angle(vertice)
rotate_mat = get_rotate_mat(theta)
rotated_vertices = rotate_vertices(vertice, theta)
x_min, x_max, y_min, y_max = get_boundary(rotated_vertices)
rotated_x, rotated_y = rotate_all_pixels(rotate_mat, vertice[0], vertice[1], length)
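        # per-pixel distances to the rotated box's top/bottom/left/right edges;
        # negative values (pixel beyond that edge) are clipped to zero below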
d1 = rotated_y - y_min
d1[d1 < 0] = 0
d2 = y_max - rotated_y
d2[d2 < 0] = 0
d3 = rotated_x - x_min
d3[d3 < 0] = 0
d4 = x_max - rotated_x
d4[d4 < 0] = 0
geo_map[:, :, 0] += d1[index_y, index_x] * temp_mask
geo_map[:, :, 1] += d2[index_y, index_x] * temp_mask
geo_map[:, :, 2] += d3[index_y, index_x] * temp_mask
geo_map[:, :, 3] += d4[index_y, index_x] * temp_mask
geo_map[:, :, 4] += theta * temp_mask
cv2.fillPoly(ignored_map, ignored_polys, 1)
cv2.fillPoly(score_map, polys, 1)
return score_map, geo_map, ignored_map
def extract_vertices(lines):
'''extract vertices info from txt lines
Input:
lines : list of string info
Output:
vertices: vertices of text regions <numpy.ndarray, (n,8)>
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
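    Example (illustrative ICDAR-style line; '###' marks an ignored region):
        >>> v, l = extract_vertices(['10,10,50,10,50,30,10,30,###\n'])
        >>> v.shape, int(l[0])
        ((1, 8), 0)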
'''
labels = []
vertices = []
for line in lines:
vertices.append(list(map(int, line.rstrip('\n').lstrip('\ufeff').split(',')[:8])))
label = 0 if '###' in line else 1
labels.append(label)
return np.array(vertices), np.array(labels)
class ICDAREASTDataset:
def __init__(self, img_path, gt_path, scale=0.25, length=512):
super(ICDAREASTDataset, self).__init__()
self.img_files = [os.path.join(
img_path,
img_file) for img_file in sorted(os.listdir(img_path))]
self.gt_files = [
os.path.join(
gt_path,
gt_file) for gt_file in sorted(
os.listdir(gt_path))]
self.scale = scale
self.length = length
def __getitem__(self, index):
with open(self.gt_files[index], 'r') as f:
lines = f.readlines()
vertices, labels = extract_vertices(lines)
img = Image.open(self.img_files[index])
img, vertices = adjust_height(img, vertices)
img, vertices = rotate_img(img, vertices)
img, vertices = crop_img(img, vertices, labels, self.length)
score_map, geo_map, ignored_map = get_score_geo(
img, vertices, labels, self.scale, self.length)
score_map = score_map.transpose(2, 0, 1)
ignored_map = ignored_map.transpose(2, 0, 1)
geo_map = geo_map.transpose(2, 0, 1)
if np.sum(score_map) < 1:
score_map[0, 0, 0] = 1
return img, score_map, geo_map, ignored_map
def __len__(self):
return len(self.img_files)
def create_east_dataset(
img_root,
txt_root,
batch_size,
device_num,
rank,
is_training=True):
east_data = ICDAREASTDataset(img_path=img_root, gt_path=txt_root)
distributed_sampler = DistributedSampler(
len(east_data), device_num, rank, shuffle=True)
trans_list = [CV.RandomColorAdjust(0.5, 0.5, 0.5, 0.25),
CV.Rescale(1 / 255.0, 0),
CV.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
CV.HWC2CHW()]
if is_training:
dataset_column_names = [
"image",
"score_map",
"geo_map",
"training_mask"]
ds = de.GeneratorDataset(
east_data,
column_names=dataset_column_names,
num_parallel_workers=32,
sampler=distributed_sampler)
ds = ds.map(
operations=trans_list,
input_columns=["image"],
num_parallel_workers=8,
python_multiprocessing=True)
ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=True)
return ds, len(east_data)
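
# Illustrative usage (paths are placeholders):
#   ds, size = create_east_dataset('/data/icdar2015/Training/image',
#                                  '/data/icdar2015/Training/groundTruth',
#                                  batch_size=8, device_num=1, rank=0)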

View File

@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(
self,
dataset_size,
num_replicas=None,
rank=None,
shuffle=True):
if num_replicas is None:
print(
"***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print(
"***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(
dataset_size *
1.0 /
self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(
seed=self.epoch).permutation(
self.dataset_size)
            # permutation of 0 .. dataset_size - 1, used to index the dataset;
            # convert from np.ndarray to list
            indices = indices.tolist()
            self.epoch += 1
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples

View File

@ -0,0 +1,363 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as P
def _conv(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
pad_mode='pad'):
"""Conv2D wrapper."""
weights = 'ones'
layers = []
layers += [nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pad_mode=pad_mode,
weight_init=weights,
has_bias=False)]
layers += [nn.BatchNorm2d(out_channels)]
return nn.SequentialCell(layers)
class VGG16FeatureExtraction(nn.Cell):
"""VGG16FeatureExtraction for deeptext"""
def __init__(self):
super(VGG16FeatureExtraction, self).__init__()
self.relu = nn.ReLU()
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv1_1 = _conv(
in_channels=3,
out_channels=64,
kernel_size=3,
padding=1)
self.conv1_2 = _conv(
in_channels=64,
out_channels=64,
kernel_size=3,
padding=1)
self.conv2_1 = _conv(
in_channels=64,
out_channels=128,
kernel_size=3,
padding=1)
self.conv2_2 = _conv(
in_channels=128,
out_channels=128,
kernel_size=3,
padding=1)
self.conv3_1 = _conv(
in_channels=128,
out_channels=256,
kernel_size=3,
padding=1)
self.conv3_2 = _conv(
in_channels=256,
out_channels=256,
kernel_size=3,
padding=1)
self.conv3_3 = _conv(
in_channels=256,
out_channels=256,
kernel_size=3,
padding=1)
self.conv4_1 = _conv(
in_channels=256,
out_channels=512,
kernel_size=3,
padding=1)
self.conv4_2 = _conv(
in_channels=512,
out_channels=512,
kernel_size=3,
padding=1)
self.conv4_3 = _conv(
in_channels=512,
out_channels=512,
kernel_size=3,
padding=1)
self.conv5_1 = _conv(
in_channels=512,
out_channels=512,
kernel_size=3,
padding=1)
self.conv5_2 = _conv(
in_channels=512,
out_channels=512,
kernel_size=3,
padding=1)
self.conv5_3 = _conv(
in_channels=512,
out_channels=512,
kernel_size=3,
padding=1)
self.cast = P.Cast()
def construct(self, out):
""" Construction of VGG """
f_0 = out
out = self.cast(out, mstype.float32)
out = self.conv1_1(out)
out = self.relu(out)
out = self.conv1_2(out)
out = self.relu(out)
out = self.max_pool(out)
out = self.conv2_1(out)
out = self.relu(out)
out = self.conv2_2(out)
out = self.relu(out)
out = self.max_pool(out)
f_2 = out
out = self.conv3_1(out)
out = self.relu(out)
out = self.conv3_2(out)
out = self.relu(out)
out = self.conv3_3(out)
out = self.relu(out)
out = self.max_pool(out)
f_3 = out
out = self.conv4_1(out)
out = self.relu(out)
out = self.conv4_2(out)
out = self.relu(out)
out = self.conv4_3(out)
out = self.relu(out)
out = self.max_pool(out)
f_4 = out
out = self.conv5_1(out)
out = self.relu(out)
out = self.conv5_2(out)
out = self.relu(out)
out = self.conv5_3(out)
out = self.relu(out)
out = self.max_pool(out)
f_5 = out
return f_0, f_2, f_3, f_4, f_5
class Merge(nn.Cell):
def __init__(self):
super(Merge, self).__init__()
self.conv1 = nn.Conv2d(1024, 128, 1, has_bias=True)
self.bn1 = nn.BatchNorm2d(128)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(
128,
128,
3,
padding=1,
pad_mode='pad',
has_bias=True)
self.bn2 = nn.BatchNorm2d(128)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(384, 64, 1, has_bias=True)
self.bn3 = nn.BatchNorm2d(64)
self.relu3 = nn.ReLU()
self.conv4 = nn.Conv2d(
64,
64,
3,
padding=1,
pad_mode='pad',
has_bias=True)
self.bn4 = nn.BatchNorm2d(64)
self.relu4 = nn.ReLU()
self.conv5 = nn.Conv2d(192, 32, 1)
self.bn5 = nn.BatchNorm2d(32)
self.relu5 = nn.ReLU()
self.conv6 = nn.Conv2d(
32,
32,
3,
padding=1,
pad_mode='pad',
has_bias=True)
self.bn6 = nn.BatchNorm2d(32)
self.relu6 = nn.ReLU()
self.conv7 = nn.Conv2d(
32,
32,
3,
padding=1,
pad_mode='pad',
has_bias=True)
self.bn7 = nn.BatchNorm2d(32)
self.relu7 = nn.ReLU()
self.concat = P.Concat(axis=1)
def construct(self, x, f1, f2, f3, f4):
        img_height = P.Shape()(x)[2]
        img_width = P.Shape()(x)[3]
        out = P.ResizeBilinear((img_height // 16, img_width // 16), True)(f4)
out = self.concat((out, f3))
out = self.relu1(self.bn1(self.conv1(out)))
out = self.relu2(self.bn2(self.conv2(out)))
        out = P.ResizeBilinear((img_height // 8, img_width // 8), True)(out)
out = self.concat((out, f2))
out = self.relu3(self.bn3(self.conv3(out)))
out = self.relu4(self.bn4(self.conv4(out)))
        out = P.ResizeBilinear((img_height // 4, img_width // 4), True)(out)
out = self.concat((out, f1))
out = self.relu5(self.bn5(self.conv5(out)))
out = self.relu6(self.bn6(self.conv6(out)))
out = self.relu7(self.bn7(self.conv7(out)))
return out
class Output(nn.Cell):
def __init__(self, scope=512):
super(Output, self).__init__()
self.conv1 = nn.Conv2d(32, 1, 1)
self.sigmoid1 = nn.Sigmoid()
self.conv2 = nn.Conv2d(32, 4, 1)
self.sigmoid2 = nn.Sigmoid()
self.conv3 = nn.Conv2d(32, 1, 1)
self.sigmoid3 = nn.Sigmoid()
self.scope = scope
self.concat = P.Concat(axis=1)
self.PI = 3.1415926535898
def construct(self, x):
score = self.sigmoid1(self.conv1(x))
loc = self.sigmoid2(self.conv2(x)) * self.scope
angle = (self.sigmoid3(self.conv3(x)) - 0.5) * self.PI
geo = self.concat((loc, angle))
return score, geo
class EAST(nn.Cell):
def __init__(self):
super(EAST, self).__init__()
self.extractor = VGG16FeatureExtraction()
self.merge = Merge()
self.output = Output()
def construct(self, x_1):
f_0, f_1, f_2, f_3, f_4 = self.extractor(x_1)
x_1 = self.merge(f_0, f_1, f_2, f_3, f_4)
score, geo = self.output(x_1)
return score, geo
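
# Illustrative shape check (assumption: a 512x512 input; the merged feature is
# at 1/4 of the input resolution, matching the dataset's scale of 0.25):
#   score, geo = EAST()(x)   # x: (N, 3, 512, 512)
#   -> score: (N, 1, 128, 128), geo: (N, 5, 128, 128)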
class DiceCoefficient(nn.Cell):
def __init__(self):
super(DiceCoefficient, self).__init__()
self.sum = P.ReduceSum()
self.eps = 1e-5
def construct(self, true_cls, pred_cls):
intersection = self.sum(true_cls * pred_cls, ())
union = self.sum(true_cls, ()) + self.sum(pred_cls, ()) + self.eps
loss = 1. - (2 * intersection / union)
return loss
class MyMin(nn.Cell):
def __init__(self):
super(MyMin, self).__init__()
self.abs = P.Abs()
def construct(self, a, b):
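        # elementwise minimum via the identity min(a, b) = (a + b - |a - b|) / 2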
return (a + b - self.abs(a - b)) / 2
class EastLossBlock(nn.Cell):
def __init__(self):
super(EastLossBlock, self).__init__()
self.split = P.Split(1, 5)
self.min = MyMin()
self.log = P.Log()
self.cos = P.Cos()
self.mean = P.ReduceMean(keep_dims=False)
self.sum = P.ReduceSum()
self.eps = 1e-5
self.dice = DiceCoefficient()
def construct(
self,
y_true_cls,
y_pred_cls,
y_true_geo,
y_pred_geo,
training_mask):
ans = self.sum(y_true_cls)
classification_loss = self.dice(
y_true_cls, y_pred_cls * (1 - training_mask))
# n * 5 * h * w
d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = self.split(y_true_geo)
d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = self.split(y_pred_geo)
area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
w_union = self.min(d2_gt, d2_pred) + self.min(d4_gt, d4_pred)
h_union = self.min(d1_gt, d1_pred) + self.min(d3_gt, d3_pred)
area_intersect = w_union * h_union
area_union = area_gt + area_pred - area_intersect
iou_loss_map = -self.log((area_intersect + 1.0) /
(area_union + 1.0)) # iou_loss_map
angle_loss_map = 1 - self.cos(theta_pred - theta_gt) # angle_loss_map
angle_loss = self.sum(angle_loss_map * y_true_cls) / ans
iou_loss = self.sum(iou_loss_map * y_true_cls) / ans
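        # combine as in the EAST paper: L_g = L_AABB + lambda_theta * L_theta,
        # with lambda_theta = 10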
geo_loss = 10 * angle_loss + iou_loss
return geo_loss + classification_loss
class EastWithLossCell(nn.Cell):
def __init__(self, network):
super(EastWithLossCell, self).__init__()
self.east_network = network
self.loss = EastLossBlock()
def construct(self, img, true_cls, true_geo, training_mask):
        score, geometry = self.east_network(img)
loss = self.loss(
true_cls,
            score,
true_geo,
geometry,
training_mask)
return loss

View File

@ -0,0 +1,184 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
"""
linear_fns = [
'linear',
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
    elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError(
"negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of 'num' and 'arr'."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError(
"Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
        arr: an n-dimensional numpy array
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
Examples:
        >>> w = np.empty((3, 5))
        >>> w = kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
# Calculate uniform bounds from standard deviation
bound = math.sqrt(3.0) * std
return np.random.uniform(-bound, bound, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
"""Calculate fan in and fan out."""
dimensions = len(arr.shape)
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = arr.shape[1]
num_output_fmaps = arr.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class KaimingUniform(MeInitializer):
"""Kaiming uniform initializer."""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, array):
tmp = kaiming_uniform_(array, self.a, self.mode, self.nonlinearity)
_assignment(array, tmp)
def default_recurisive_init(custom_cell):
"""Initialize parameter."""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(
init.initializer(
KaimingUniform(
a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass

View File

@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
self.log_fn = ''
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + \
'_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
file = logging.FileHandler(self.log_fn)
file.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
file.setFormatter(formatter)
self.addHandler(file)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('east_vgg', rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate scheduler."""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate."""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_step_lr(
lr,
lr_epochs,
steps_per_epoch,
warmup_epochs,
max_epoch,
gamma=0.1):
"""Warmup step learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma ** milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(
lr,
milestones,
steps_per_epoch,
0,
max_epoch,
gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(
lr,
lr_epochs,
steps_per_epoch,
max_epoch,
gamma=gamma)
def warmup_cosine_annealing_lr(
lr,
steps_per_epoch,
warmup_epochs,
max_epoch,
t_max,
eta_min=0):
"""Cosine annealing learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * \
(1. + math.cos(math.pi * last_epoch / t_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_v2(
lr,
steps_per_epoch,
warmup_epochs,
max_epoch,
t_max,
eta_min=0):
"""Cosine annealing learning rate V2."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
last_lr = 0
last_epoch_v1 = 0
t_max_v2 = int(max_epoch * 1 / 3)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
if i < total_steps * 2 / 3:
lr = eta_min + (base_lr - eta_min) * (1. + \
math.cos(math.pi * last_epoch / t_max)) / 2
last_lr = lr
last_epoch_v1 = last_epoch
else:
base_lr = last_lr
last_epoch = last_epoch - last_epoch_v1
lr = eta_min + (base_lr - eta_min) * (1. + \
math.cos(math.pi * last_epoch / t_max_v2)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_sample(
lr,
steps_per_epoch,
warmup_epochs,
max_epoch,
t_max,
eta_min=0):
"""Warmup cosine annealing learning rate."""
start_sample_epoch = 60
step_sample = 2
tobe_sampled_epoch = 60
end_sampled_epoch = start_sample_epoch + step_sample * tobe_sampled_epoch
max_sampled_epoch = max_epoch + tobe_sampled_epoch
t_max = max_sampled_epoch
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_sampled_steps):
last_epoch = i // steps_per_epoch
if last_epoch in range(
start_sample_epoch,
end_sampled_epoch,
step_sample):
continue
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * \
(1. + math.cos(math.pi * last_epoch / t_max)) / 2
lr_each_step.append(lr)
assert total_steps == len(lr_each_step)
return np.array(lr_each_step).astype(np.float32)
def my_lr(max_epoch, steps_per_epoch, per_step=2, gamma=0.1, lr=0.001):
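    """Step-decay schedule (sketch): lr is multiplied by gamma every
    total_steps // per_step steps, i.e. it decays per_step - 1 times over
    training when total_steps divides evenly."""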
lr_each_step = []
total_steps = steps_per_epoch * max_epoch
n = total_steps // per_step
for i in range(max_epoch * steps_per_epoch):
if i % n == 0 and i != 0:
lr = lr * gamma
lr_each_step.append(lr)
assert total_steps == len(lr_each_step)
return np.array(lr_each_step).astype(np.float32)
def get_lr(args):
"""generate learning rate."""
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_V2':
lr = warmup_cosine_annealing_lr_v2(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_sample':
lr = warmup_cosine_annealing_lr_sample(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'my_lr':
lr = my_lr(
args.max_epoch,
args.steps_per_epoch,
args.per_step,
args.lr_gamma,
args.lr)
else:
raise NotImplementedError(args.lr_scheduler)
return lr

View File

@ -0,0 +1,121 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
def default_wd_filter(x):
"""default weight decay filter."""
parameter_name = x.name
    if parameter_name.endswith('.bias'):
        # biases do not use weight decay
        return False
    if parameter_name.endswith('.gamma'):
        # BN gamma does not use weight decay
        return False
    if parameter_name.endswith('.beta'):
        # BN beta does not use weight decay
        return False
return True
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
        if parameter_name.endswith('.bias'):
            # biases do not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # BN gamma does not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # BN beta does not use weight decay
            no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0},
{'params': decay_params}]
class ShapeRecord:
"""Log image shape."""
def __init__(self):
self.shape_record = {
416: 0,
448: 0,
480: 0,
512: 0,
544: 0,
576: 0,
608: 0,
640: 0,
672: 0,
704: 0,
736: 0,
'total': 0
}
def set(self, shape):
if len(shape) > 1:
shape = shape[0]
shape = int(shape)
self.shape_record[shape] += 1
self.shape_record['total'] += 1
    def show(self, logger):
        for key in self.shape_record:
            if key == 'total':
                continue
            rate = self.shape_record[key] / float(self.shape_record['total'])
            logger.info('shape {}: {:.2f}%'.format(key, rate * 100))

View File

@ -0,0 +1,318 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import datetime
from mindspore.context import ParallelMode
from mindspore.nn.optim.adam import Adam
from mindspore import Tensor, Model
from mindspore import context
from mindspore.communication.management import init
import mindspore as ms
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
from mindspore.profiler.profiling import Profiler
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from src.util import AverageMeter, get_param_groups
from src.east import EAST, EastWithLossCell
from src.logger import get_logger
from src.initializer import default_recurisive_init
from src.dataset import create_east_dataset
from src.lr_scheduler import get_lr
set_seed(1)
parser = argparse.ArgumentParser('mindspore icdar training')
# device related
parser.add_argument(
'--device_target',
type=str,
default='Ascend',
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument(
    '--device_id',
    type=int,
    default=0,
    help='Device id of the target device. Default: 0')
# dataset related
parser.add_argument(
'--data_dir',
default='/data/icdar2015/Training/',
type=str,
help='Train dataset directory.')
parser.add_argument(
'--per_batch_size',
default=24,
type=int,
help='Batch size for Training. Default: 24.')
parser.add_argument(
'--outputs_dir',
default='outputs/',
type=str,
help='output dir. Default: outputs/')
# network related
parser.add_argument(
    '--pretrained_backbone',
    default='/data/vgg/0-150_5004.ckpt',
    type=str,
    help='The ckpt file of the pretrained VGG16 backbone. Default: /data/vgg/0-150_5004.ckpt')
parser.add_argument(
'--resume_east',
default='',
type=str,
help='The ckpt file of EAST, which used to fine tune. Default: ""')
# optimizer and lr related
parser.add_argument(
    '--lr_scheduler',
    default='my_lr',
    type=str,
    help='Learning rate scheduler, options: exponential, cosine_annealing, '
         'cosine_annealing_V2, cosine_annealing_sample, my_lr. Default: my_lr')
parser.add_argument('--lr', default=0.001, type=float,
help='Learning rate. Default: 0.001')
parser.add_argument('--per_step', default=2, type=int,
                    help='Number of equal lr segments for my_lr. Default: 2')
parser.add_argument(
'--lr_gamma',
type=float,
default=0.1,
help='Decrease lr by a factor of exponential lr_scheduler. Default: 0.1')
parser.add_argument(
'--eta_min',
type=float,
default=0.,
help='Eta_min in cosine_annealing scheduler. Default: 0.')
parser.add_argument(
'--t_max',
type=int,
default=100,
help='T-max in cosine_annealing scheduler. Default: 100')
parser.add_argument('--max_epoch', type=int, default=600,
                    help='Max epoch num to train the model. Default: 600')
parser.add_argument(
'--warmup_epochs',
default=6,
type=float,
help='Warmup epochs. Default: 6')
parser.add_argument(
'--weight_decay',
type=float,
default=0.0005,
help='Weight decay factor. Default: 0.0005')
# loss related
parser.add_argument('--loss_scale', type=int, default=1,
                    help='Static loss scale. Default: 1')
parser.add_argument(
    '--lr_epochs',
    type=str,
    default='7,7',
    help='Epochs at which lr changes, comma-separated. Default: 7,7')
# logging related
parser.add_argument('--log_interval', type=int, default=10,
                    help='Logging interval steps. Default: 10')
parser.add_argument(
'--ckpt_path',
type=str,
default='outputs/',
help='Checkpoint save location. Default: outputs/')
parser.add_argument(
    '--ckpt_interval',
    type=int,
    default=1000,
    help='Save checkpoint interval steps. Default: 1000')
parser.add_argument(
'--is_save_on_master',
type=int,
default=1,
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')
# distributed related
parser.add_argument(
    '--is_distributed',
    type=int,
    default=0,
    help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
parser.add_argument(
'--rank',
type=int,
default=0,
help='Local rank of distributed. Default: 0')
parser.add_argument('--group_size', type=int, default=1,
help='World size of device. Default: 1')
# profiler init
parser.add_argument(
'--need_profiler',
type=int,
default=0,
help='Whether use profiler. 0 for no, 1 for yes. Default: 0')
# modelArts
parser.add_argument(
    '--is_modelArts',
    type=int,
    default=0,
    help='Training on ModelArts or not, 1 for yes, 0 for no. Default: 0')
args, _ = parser.parse_known_args()
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.rank = args.device_id
# init distributed
if args.is_distributed:
if args.device_target == "Ascend":
init()
else:
init("nccl")
args.rank = int(os.getenv('DEVICE_ID'))
args.group_size = int(os.getenv('RANK_SIZE'))
context.set_context(
mode=context.GRAPH_MODE,
enable_auto_mixed_precision=True,
device_target=args.device_target,
save_graphs=False,
device_id=args.rank)
# select whether only the master rank saves ckpt or all ranks do; compatible
# with model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
if args.is_modelArts:
import moxing as mox
local_data_url = os.path.join('/cache/data', str(args.rank))
local_ckpt_url = os.path.join('/cache/ckpt', str(args.rank))
local_ckpt_url = os.path.join(local_ckpt_url, 'backbone.ckpt')
mox.file.rename(args.pretrained_backbone, local_ckpt_url)
args.pretrained_backbone = local_ckpt_url
mox.file.copy_parallel(args.data_dir, local_data_url)
args.data_dir = local_data_url
args.outputs_dir = os.path.join('/cache', args.outputs_dir)
args.data_root = os.path.abspath(os.path.join(args.data_dir, 'image'))
args.txt_root = os.path.abspath(os.path.join(args.data_dir, 'groundTruth'))
outputs_dir = os.path.join(args.outputs_dir, str(args.rank))
args.outputs_dir = os.path.join(
args.outputs_dir,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
if __name__ == "__main__":
if args.need_profiler:
profiler = Profiler(
output_path=args.outputs_dir,
is_detail=True,
is_show_op_path=True)
loss_meter = AverageMeter('loss')
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
degree = 1
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
degree = int(os.getenv('RANK_SIZE'))
context.set_auto_parallel_context(
parallel_mode=parallel_mode,
gradients_mean=True,
device_num=degree)
network = EAST()
    # default init is kaiming-uniform
default_recurisive_init(network)
# load pretrained_backbone
if args.pretrained_backbone:
parm_dict = load_checkpoint(args.pretrained_backbone)
load_param_into_net(network, parm_dict)
args.logger.info('finish load pretrained_backbone')
network = EastWithLossCell(network)
if args.resume_east:
parm_dict = load_checkpoint(args.resume_east)
load_param_into_net(network, parm_dict)
args.logger.info('finish get resume east')
args.logger.info('finish get network')
ds, data_size = create_east_dataset(img_root=args.data_root, txt_root=args.txt_root, batch_size=args.per_batch_size,
device_num=args.group_size, rank=args.rank, is_training=True)
args.logger.info('Finish loading dataset')
args.steps_per_epoch = int(
data_size /
args.per_batch_size /
args.group_size)
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
    # get learning rate
lr = get_lr(args)
opt = Adam(
params=get_param_groups(network),
learning_rate=Tensor(
lr,
ms.float32))
loss_scale = FixedLossScaleManager(1.0, drop_overflow_update=True)
model = Model(network, optimizer=opt, loss_scale_manager=loss_scale)
network.set_train()
# save the network model and parameters for subsequence fine-tuning
config_ck = CheckpointConfig(
save_checkpoint_steps=100,
keep_checkpoint_max=1)
# group layers into an object with training and evaluation features
save_ckpt_path = os.path.join(
args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
ckpoint_cb = ModelCheckpoint(
prefix="checkpoint_east",
directory=save_ckpt_path,
config=config_ck)
callback = []
if args.rank == 0:
callback = [
TimeMonitor(
data_size=data_size),
LossMonitor(),
ckpoint_cb]
model.train(
args.max_epoch,
ds,
callbacks=callback,
dataset_sink_mode=False)
args.logger.info('==========end training===============')