forked from mindspore-Ecosystem/mindspore
commit
7faee3bffe
|
@ -0,0 +1,292 @@
|
|||
# EAST for Ascend
|
||||
|
||||
- [EAST Description](#east-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [EAST Description](#contents)
|
||||
|
||||
EAST is an efficient and accurate neural network pipeline for scene text detection. The method has two stages: a fully convolutional network (FCN) stage and a non-maximum suppression (NMS) merging stage. The FCN directly produces text regions, which removes redundant and time-consuming intermediate steps. The approach was proposed in the paper "EAST: An Efficient and Accurate Scene Text Detector", published in 2017.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1704.03155.pdf): Xinyu Zhou, Cong Yao, He Wen, Yuzhi Wang, Shuchang Zhou, Weiran He, and Jiajun Liang. Megvii Technology Inc., Beijing, China. Published in CVPR 2017.
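
In the paper, the second stage merges overlapping candidate boxes with a locality-aware NMS, whose core step is a score-weighted merge of two quadrangles. The snippet below is a minimal pure-Python sketch of that merge step only, not the C++ implementation used by this repository. The function name `weighted_merge` and the sample boxes are illustrative assumptions.

```python
def weighted_merge(g, p):
    """Merge two scored quadrangles given as [x1, y1, ..., x4, y4, score].

    Coordinates are averaged with the scores as weights; the merged
    score is the sum of both scores.
    """
    total = g[8] + p[8]
    merged = [(g[8] * gc + p[8] * pc) / total for gc, pc in zip(g[:8], p[:8])]
    merged.append(total)
    return merged


# Two hypothetical, heavily overlapping detections of the same word.
a = [0, 0, 10, 0, 10, 4, 0, 4, 0.9]
b = [2, 0, 12, 0, 12, 4, 2, 4, 0.3]
merged = weighted_merge(a, b)
print([round(v, 3) for v in merged])
```

Because the weights are the scores, the merged box stays closer to the more confident detection.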
|
||||
|
||||
# [Model architecture](#contents)
|
||||
|
||||
The network can be decomposed into three parts: feature extraction, feature merging, and the output layer. The feature extraction part uses a backbone such as VGG or ResNet50 to obtain feature maps. The feature merging part borrows the idea of U-Net to combine information from different levels. Finally, the output layer produces the score map and the geometry map.
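
The post-processing code in this commit treats the geometry map as five channels per pixel: four distances to the box edges and one rotation angle. The sketch below shows how one such entry can be turned into the four corners of a rotated box. It is a pure-Python illustration under two assumptions: the function name `decode_rbox` is hypothetical, and the rotation is taken to be a standard 2-D rotation by `-angle` (the repository's `get_rotate_mat` is not shown here).

```python
import math


def decode_rbox(x, y, d_top, d_bottom, d_left, d_right, angle):
    """Return the four corners of a rotated box as (x, y) tuples."""
    # Axis-aligned corners relative to the pixel (x, y):
    # top-left, top-right, bottom-right, bottom-left.
    rel = [(-d_left, -d_top), (d_right, -d_top), (d_right, d_bottom), (-d_left, d_bottom)]
    # Assumed convention: rotate the relative corners by -angle around the pixel.
    cos_a, sin_a = math.cos(-angle), math.sin(-angle)
    return [(x + cos_a * dx - sin_a * dy, y + sin_a * dx + cos_a * dy) for dx, dy in rel]


corners = decode_rbox(x=100, y=50, d_top=5, d_bottom=5, d_left=20, d_right=30, angle=0.0)
print(corners)
```

With `angle=0.0` the result is simply the axis-aligned rectangle around the pixel, which is a quick sanity check for the distance ordering.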
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Dataset used: [ICDAR 2015](https://rrc.cvc.uab.es/?ch=4&com=downloads)
|
||||
|
||||
- Dataset: ICDAR 2015: Incidental Scene Text
    - Train: 88.5MB, 1000 images
    - Test: 43.3MB, 500 images
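
Each image in this dataset comes with a ground-truth text file. The sketch below parses one annotation line, assuming the usual ICDAR 2015 layout: eight comma-separated corner coordinates followed by the transcription, with `###` marking regions that are ignored during evaluation. The function name `parse_gt_line` and the sample line are illustrative, not part of this repository.

```python
def parse_gt_line(line):
    """Parse one annotation line into (corner points, transcription, is_ignored)."""
    # Some annotation files start with a UTF-8 byte-order mark; strip it.
    # Split at most 8 times so commas inside the transcription are preserved.
    fields = line.lstrip("\ufeff").rstrip("\r\n").split(",", 8)
    coords = [int(v) for v in fields[:8]]
    points = list(zip(coords[0::2], coords[1::2]))
    text = fields[8]
    # "###" marks a region that is ignored during evaluation.
    return points, text, text == "###"


points, text, ignored = parse_gt_line("377,117,463,117,465,130,378,130,Genaxis Theatre\n")
print(points, text, ignored)
```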
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script description](#contents)
|
||||
|
||||
## [Script and sample code](#contents)
|
||||
|
||||
```shell
|
||||
.
└─east
  ├─README.md
  ├─scripts
    ├─run_standalone_train.sh     # launch standalone training on the Ascend platform (1p)
    ├─run_distribute.sh           # launch distributed training on the Ascend platform (8p)
    └─run_eval.sh                 # launch evaluation on the Ascend platform
  ├─src
    ├─dataset.py                  # data preprocessing
    ├─lr_schedule.py              # learning rate scheduler
    ├─east.py                     # network definition
    ├─utils.py                    # commonly used utility functions
    ├─distributed_sampler.py      # distributed train
    ├─initializer.py              # init
    └─logger.py                   # logger output
  ├─eval.py                       # eval net
  └─train.py                      # train net
|
||||
```
|
||||
|
||||
## [Training process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
- Ascend:
|
||||
|
||||
```bash
|
||||
# distribute training example(8p)
|
||||
sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]
|
||||
# standalone training
|
||||
sh run_standalone_train_ascend.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]
|
||||
# evaluation:
|
||||
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
> Notes:
> For the format of RANK_TABLE_FILE, refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models such as InceptionV4, it is better to set the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection check time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compile time increases with model size.
>
> The distributed training script binds processor cores according to `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
>
> The `pretrained_path` should be a checkpoint of VGG16 trained on ImageNet2012. The weight names in the checkpoint dict must match the network's exactly, and batch_norm must be enabled when training VGG16; otherwise, later steps will fail.
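
Since the weight names in the pretrained checkpoint must match the backbone exactly, it can help to compare the two sets of names before training. The following is a minimal sketch in plain Python; the helper name `report_name_mismatch` and the parameter names are hypothetical. In practice the two name lists might come from the keys of the dict returned by `load_checkpoint` and from the network's parameter dict, but that wiring is an assumption and is not shown.

```python
def report_name_mismatch(checkpoint_names, network_names):
    """Return (missing, unexpected) parameter names, both sorted."""
    ckpt, net = set(checkpoint_names), set(network_names)
    missing = sorted(net - ckpt)      # expected by the network, absent from the checkpoint
    unexpected = sorted(ckpt - net)   # present in the checkpoint, unknown to the network
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
ckpt_names = ["layers.0.weight", "layers.1.gamma", "layers.1.beta"]
net_names = ["layers.0.weight", "layers.1.gamma", "layers.1.beta", "layers.1.moving_mean"]
missing, unexpected = report_name_mismatch(ckpt_names, net_names)
print("missing:", missing, "unexpected:", unexpected)
```

A non-empty `missing` list containing batch-norm statistics would suggest that the checkpoint was trained without batch normalization.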
|
||||
### Launch
|
||||
|
||||
```bash
|
||||
# training example
|
||||
# shell (Ascend):
|
||||
# distribute training example(8p)
|
||||
sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]
|
||||
# standalone training
|
||||
sh run_standalone_train_ascend.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training results are stored in the example path. Checkpoints are stored at `ckpt_path` by default, and the training log is redirected to `./log`.
|
||||
|
||||
```python
|
||||
(8p)
|
||||
...
|
||||
epoch: 397 step: 1, loss is 0.2616188
|
||||
epoch: 397 step: 2, loss is 0.38392675
|
||||
epoch: 397 step: 3, loss is 0.21342245
|
||||
epoch: 397 step: 4, loss is 0.29853413
|
||||
epoch: 397 step: 5, loss is 0.2697169
|
||||
epoch time: 4432.678 ms, per step time: 886.536 ms
|
||||
epoch: 398 step: 1, loss is 0.32656515
|
||||
epoch: 398 step: 2, loss is 0.28596723
|
||||
epoch: 398 step: 3, loss is 0.24983373
|
||||
epoch: 398 step: 4, loss is 0.29556546
|
||||
epoch: 398 step: 5, loss is 0.28608245
|
||||
epoch time: 5230.462 ms, per step time: 1046.092 ms
|
||||
epoch: 399 step: 1, loss is 0.24444203
|
||||
epoch: 399 step: 2, loss is 0.24407807
|
||||
epoch: 399 step: 3, loss is 0.29774582
|
||||
epoch: 399 step: 4, loss is 0.2569809
|
||||
epoch: 399 step: 5, loss is 0.25168353
|
||||
epoch time: 2595.220 ms, per step time: 519.044 ms
|
||||
epoch: 400 step: 1, loss is 0.21435773
|
||||
epoch: 400 step: 2, loss is 0.2563093
|
||||
epoch: 400 step: 3, loss is 0.23374572
|
||||
epoch: 400 step: 4, loss is 0.457117
|
||||
epoch: 400 step: 5, loss is 0.28918257
|
||||
epoch time: 4661.479 ms, per step time: 932.296 ms
|
||||
epoch: 401 step: 1, loss is 0.26602226
|
||||
epoch: 401 step: 2, loss is 0.267757
|
||||
epoch: 401 step: 3, loss is 0.27752787
|
||||
epoch: 401 step: 4, loss is 0.28883433
|
||||
epoch: 401 step: 5, loss is 0.20567583
|
||||
epoch time: 4297.705 ms, per step time: 859.541 ms
|
||||
...
|
||||
(1p)
|
||||
...
|
||||
epoch time: 20190.564 ms, per step time: 492.453 ms
|
||||
epoch: 23 step: 1, loss is 1.4938335
|
||||
epoch: 23 step: 2, loss is 1.7320133
|
||||
epoch: 23 step: 3, loss is 1.3432003
|
||||
epoch: 23 step: 4, loss is 1.375334
|
||||
epoch: 23 step: 5, loss is 1.2183237
|
||||
epoch: 23 step: 6, loss is 1.152751
|
||||
epoch: 23 step: 7, loss is 1.1234403
|
||||
epoch: 23 step: 8, loss is 1.1597326
|
||||
epoch: 23 step: 9, loss is 1.390804
|
||||
epoch: 23 step: 10, loss is 1.2011471
|
||||
epoch: 23 step: 11, loss is 1.7939932
|
||||
epoch: 23 step: 12, loss is 1.7997816
|
||||
epoch: 23 step: 13, loss is 1.4836912
|
||||
epoch: 23 step: 14, loss is 1.3689598
|
||||
epoch: 23 step: 15, loss is 1.3506227
|
||||
epoch: 23 step: 16, loss is 2.132399
|
||||
epoch: 23 step: 17, loss is 1.4153867
|
||||
epoch: 23 step: 18, loss is 1.351174
|
||||
epoch: 23 step: 19, loss is 1.9559281
|
||||
epoch: 23 step: 20, loss is 1.317142
|
||||
epoch: 23 step: 21, loss is 1.4965435
|
||||
epoch: 23 step: 22, loss is 1.2664857
|
||||
epoch: 23 step: 23, loss is 1.7235017
|
||||
epoch: 23 step: 24, loss is 1.4537313
|
||||
epoch: 23 step: 25, loss is 1.7973338
|
||||
epoch: 23 step: 26, loss is 1.583169
|
||||
epoch: 23 step: 27, loss is 1.5295832
|
||||
epoch: 23 step: 28, loss is 2.0665898
|
||||
epoch: 23 step: 29, loss is 1.3507215
|
||||
epoch: 23 step: 30, loss is 1.2847648
|
||||
epoch: 23 step: 31, loss is 1.5181551
|
||||
epoch: 23 step: 32, loss is 1.4159863
|
||||
epoch: 23 step: 33, loss is 1.4176369
|
||||
epoch: 23 step: 34, loss is 1.4142565
|
||||
epoch: 23 step: 35, loss is 1.3644646
|
||||
epoch: 23 step: 36, loss is 1.1788905
|
||||
epoch: 23 step: 37, loss is 1.4377214
|
||||
epoch: 23 step: 38, loss is 1.108615
|
||||
epoch: 23 step: 39, loss is 1.2742603
|
||||
epoch: 23 step: 40, loss is 1.3961313
|
||||
epoch: 23 step: 41, loss is 1.3044286
|
||||
...
|
||||
```
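
The `epoch time` lines in the log above can be used to recover the number of steps per epoch, which is a quick way to check that the 8p and 1p runs split the data as expected. The snippet below is a small sketch that parses such a line; the helper name `steps_per_epoch` is hypothetical and the regular expression only assumes the line format shown in the log.

```python
import re

EPOCH_TIME = re.compile(r"epoch time: ([\d.]+) ms, per step time: ([\d.]+) ms")


def steps_per_epoch(log_line):
    """Return the step count implied by an 'epoch time' log line, or None."""
    match = EPOCH_TIME.search(log_line)
    if match is None:
        return None
    epoch_ms, step_ms = float(match.group(1)), float(match.group(2))
    return round(epoch_ms / step_ms)


print(steps_per_epoch("epoch time: 4432.678 ms, per step time: 886.536 ms"))
print(steps_per_epoch("epoch time: 20190.564 ms, per step time: 492.453 ms"))
```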
|
||||
|
||||
## [Evaluation process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
You can start evaluation using Python or shell scripts. The usage of the shell script is as follows:
|
||||
|
||||
- Ascend:
|
||||
|
||||
```bash
|
||||
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
### Launch
|
||||
|
||||
- A fast locality-aware NMS implemented in C++ is provided by the paper's author (g++/gcc version 6.0 or later is required); you can get it [here](https://github.com/argman/EAST).
- Download the [evaluation tool](https://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1) before evaluating, rename it to **evaluate**, and arrange the directory as follows:
|
||||
|
||||
```shell
|
||||
├─lanms                               # lanms tool
├─evaluate
  ├─gt.zip                            # test ground truth
  ├─rrc_evaluation_funcs_1_1.py       # evaluation tool from ICDAR 2015
  └─script.py                         # evaluation tool from ICDAR 2015
├─eval.py                             # eval net
|
||||
```
|
||||
|
||||
- The evaluation scripts are from [ICDAR Offline evaluation](http://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1) and have been modified to run successfully with Python 3.7.1.
|
||||
- Replace `evaluate/gt.zip` if you test on other datasets.
|
||||
- Modify the parameters in `eval.py` and run:
|
||||
|
||||
```bash
|
||||
# eval example
|
||||
# shell (Ascend):
|
||||
sh run_eval_ascend.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]
|
||||
```
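
The evaluation tool reads one result file per test image. Based on what `detect_dataset` in this commit writes, each file is named `res_<image name>.txt` and holds one line of eight integer coordinates per detected box. The sketch below reproduces only that formatting step in plain Python; the helper names and the sample box are illustrative.

```python
def format_result_lines(boxes):
    """Turn boxes [x1, y1, ..., x4, y4, score] into 'x1,y1,...,x4,y4' lines."""
    return [",".join(str(int(v)) for v in box[:8]) + "\n" for box in boxes]


def result_file_name(image_name):
    """Map 'img_1.jpg' to 'res_img_1.txt'."""
    return "res_" + image_name.replace(".jpg", ".txt")


# One hypothetical detection: four corners followed by a score.
lines = format_result_lines([[10.0, 20.0, 110.0, 20.0, 110.0, 60.0, 10.0, 60.0, 0.97]])
print(result_file_name("img_1.jpg"), lines)
```

The score is dropped on purpose, since only the eight coordinates are written to the result file.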
|
||||
|
||||
> The checkpoint is produced during the training process.
|
||||
|
||||
### Result
|
||||
|
||||
Evaluation results are stored in the example path; you can find results like the following in `log`.
|
||||
|
||||
```python
|
||||
Calculated {"precision": 0.8329088130412634, "recall": 0.7871930669234473, "hmean": 0.8094059405940593, "AP": 0}
|
||||
```
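
The `hmean` value in the result above is the harmonic mean of precision and recall. As a quick check, the sketch below recomputes it from the two reported values; the function name `hmean` is only an illustrative helper and not part of the evaluation tool.

```python
def hmean(precision, recall):
    """Harmonic mean (F1 score) of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


value = hmean(0.8329088130412634, 0.7871930669234473)
print(round(value, 4))
```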
|
||||
|
||||
# [Model description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters          | Ascend                                                       |
| ------------------- | ------------------------------------------------------------ |
| Model Version       | EAST                                                         |
| Resource            | Ascend 910, CPU: 2.60GHz, 192 cores, memory: 755G            |
| Uploaded Date       | 04/27/2021                                                   |
| MindSpore Version   | 1.1.1                                                        |
| Dataset             | 1000 images                                                  |
| Batch_size          | 8                                                            |
| Training Parameters | epoch=600, batch_size=8, lr=0.001                            |
| Optimizer           | Adam                                                         |
| Loss Function       | Dice for classification, IoU for bbox regression             |
| Loss                | ~0.27                                                        |
| Total time (8p)     | 1h20m                                                        |
| Scripts             | [east script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/east) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters          | Ascend                                                       |
| ------------------- | ------------------------------------------------------------ |
| Model Version       | EAST                                                         |
| Resource            | Ascend 910, CPU: 2.60GHz, 192 cores, memory: 755G            |
| Uploaded Date       | 12/27/2021                                                   |
| MindSpore Version   | 1.1.1                                                        |
| Dataset             | 500 images                                                   |
| Batch_size          | 1                                                            |
| Accuracy            | "precision": 0.8329088130412634, "recall": 0.7871930669234473, "hmean": 0.8094059405940593 |
| Total time          | 2 min                                                        |
| Model for inference | 172.7M (.ckpt file)                                          |
|
||||
|
||||
### Training performance results
|
||||
|
||||
| **Ascend** | train performance |
| :--------: | :---------------: |
|     1p     |    51.25 img/s    |
|     8p     |     300 img/s     |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We set seed to 1 in train.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
# option() only declares boolean switches; a path belongs in a cache variable.
set(MINDSPORE_PATH "" CACHE PATH "mindspore install path")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
#include <string_view>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);

#endif  // MINDSPORE_INFERENCE_UTILS_H_
|
|
@ -0,0 +1,118 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
#include <map>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "../inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(input0_path, ".", "input0 path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_mindir_path).empty()) {
    std::cout << "Invalid mindir" << std::endl;
    return 1;
  }

  // NOTE: assumed reconstruction. `model` and `ret` are used below but were never declared
  // in this file; this setup follows the MindSpore C++ API implied by the using-declarations above.
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph graph;
  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
  Model model;
  Status ret = model.Build(GraphCell(graph), context);
  if (ret != kSuccess) {
    std::cout << "ERROR: Build failed." << std::endl;
    return 1;
  }

  std::vector<MSTensor> model_inputs = model.GetInputs();
  if (model_inputs.empty()) {
    std::cout << "Invalid model, inputs is empty." << std::endl;
    return 1;
  }

  auto input0_files = GetAllFiles(FLAGS_input0_path);

  if (input0_files.empty()) {
    std::cout << "ERROR: input data empty." << std::endl;
    return 1;
  }

  // Maps the start time of each inference to its end time, both in milliseconds.
  std::map<double, double> costTime_map;
  size_t size = input0_files.size();

  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    double startTimeMs;
    double endTimeMs;
    std::vector<MSTensor> inputs;
    std::vector<MSTensor> outputs;
    std::cout << "Start predict input files:" << input0_files[i] << std::endl;

    auto input0 = ReadFileToTensor(input0_files[i]);

    inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
                        input0.Data().get(), input0.DataSize());

    gettimeofday(&start, nullptr);
    ret = model.Predict(inputs, &outputs);
    gettimeofday(&end, nullptr);
    if (ret != kSuccess) {
      std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
      return 1;
    }
    startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
    endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
    costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    WriteResult(input0_files[i], outputs);
  }
  double average = 0.0;
  int inferCount = 0;

  // Accumulate the per-image latency and report the average.
  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  average = average / inferCount;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "../inc/utils.h"
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    // Skip the special entries and anything that is not a regular file.
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  closedir(dir);  // release the directory handle once all entries are read
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    // Replace the file extension with "_<output index>.bin".
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE * outputFile = fopen(outFileName.c_str(), "wb");
    if (outputFile == nullptr) {  // for example, ./result_Files does not exist
      std::cout << "Open file " << outFileName << " failed." << std::endl;
      return 1;
    }
    fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "File path is empty" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file, std::ios::binary);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  // Determine the file size, then read the whole file into a uint8 tensor.
  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}
|
||||
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  // lstat leaves `s` unset on failure, so check its return value before reading st_mode.
  if (lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  // std::string_view is not guaranteed to be null-terminated, so copy it before calling realpath().
  realPathRet = realpath(std::string(path).c_str(), realPathMem);

  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops as P
|
||||
from mindspore import Tensor
|
||||
import mindspore.dataset.vision.py_transforms as V
|
||||
from src.dataset import get_rotate_mat
|
||||
|
||||
import lanms
|
||||
|
||||
|
||||
def resize_img(img):
    """resize image to be divisible by 32
    """
    w, h = img.size
    resize_w = w
    resize_h = h

    resize_h = resize_h if resize_h % 32 == 0 else int(resize_h / 32) * 32
    resize_w = resize_w if resize_w % 32 == 0 else int(resize_w / 32) * 32
    img = img.resize((resize_w, resize_h), Image.BILINEAR)
    ratio_h = resize_h / h
    ratio_w = resize_w / w

    return img, ratio_h, ratio_w
|
||||
|
||||
|
||||
def load_pil(img):
    """convert PIL Image to Tensor
    """
    img = V.ToTensor()(img)
    img = V.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
    img = Tensor(img)
    img = P.ExpandDims()(img, 0)
    return img
|
||||
|
||||
|
||||
def is_valid_poly(res, score_shape, scale):
    """check if the poly in image scope
    Input:
        res        : restored poly in original image
        score_shape: score map shape
        scale      : feature map -> image
    Output:
        True if valid
    """
    cnt = 0
    for i in range(res.shape[1]):
        if res[0, i] < 0 or res[0, i] >= score_shape[1] * scale or \
           res[1, i] < 0 or res[1, i] >= score_shape[0] * scale:
            cnt += 1
    return cnt <= 1
|
||||
|
||||
|
||||
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
    """restore polys from feature maps in given positions
    Input:
        valid_pos  : potential text positions <numpy.ndarray, (n,2)>
        valid_geo  : geometry in valid_pos <numpy.ndarray, (5,n)>
        score_shape: shape of score map
        scale      : image / feature map
    Output:
        restored polys <numpy.ndarray, (n,8)>, index
    """
    polys = []
    index = []
    valid_pos *= scale
    d = valid_geo[:4, :]  # 4 x N
    angle = valid_geo[4, :]  # N,

    for i in range(valid_pos.shape[0]):
        x = valid_pos[i, 0]
        y = valid_pos[i, 1]
        y_min = y - d[0, i]
        y_max = y + d[1, i]
        x_min = x - d[2, i]
        x_max = x + d[3, i]
        rotate_mat = get_rotate_mat(-angle[i])

        temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
        temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
        coordinates = np.concatenate((temp_x, temp_y), axis=0)
        res = np.dot(rotate_mat, coordinates)
        res[0, :] += x
        res[1, :] += y

        if is_valid_poly(res, score_shape, scale):
            index.append(i)
            polys.append([res[0, 0], res[1, 0], res[0, 1], res[1, 1],
                          res[0, 2], res[1, 2], res[0, 3], res[1, 3]])
    return np.array(polys), index
|
||||
|
||||
|
||||
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
    """get boxes from feature map
    Input:
        score       : score map from model <numpy.ndarray, (1,row,col)>
        geo         : geo map from model <numpy.ndarray, (5,row,col)>
        score_thresh: threshold to segment score map
        nms_thresh  : threshold in nms
    Output:
        boxes       : final polys <numpy.ndarray, (n,9)>
    """
    score = score[0, :, :]
    xy_text = np.argwhere(score > score_thresh)  # n x 2, format is [r, c]
    if xy_text.size == 0:
        return None

    xy_text = xy_text[np.argsort(xy_text[:, 0])]
    valid_pos = xy_text[:, ::-1].copy()  # n x 2, [x, y]
    valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]]  # 5 x n
    polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape)
    if polys_restored.size == 0:
        return None

    boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
    boxes[:, :8] = polys_restored
    boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]]
    boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
    return boxes
|
||||
|
||||
|
||||
def adjust_ratio(boxes, ratio_w, ratio_h):
    """refine boxes
    Input:
        boxes  : detected polys <numpy.ndarray, (n,9)>
        ratio_w: ratio of width
        ratio_h: ratio of height
    Output:
        refined boxes
    """
    if boxes is None or boxes.size == 0:
        return None
    boxes[:, [0, 2, 4, 6]] /= ratio_w
    boxes[:, [1, 3, 5, 7]] /= ratio_h
    return np.around(boxes)
|
||||
|
||||
|
||||
def detect(img, model):
    """detect text regions of img using model
    Input:
        img   : PIL Image
        model : detection model
    Output:
        detected polys
    """
    img, ratio_h, ratio_w = resize_img(img)
    score, geo = model(load_pil(img))
    score = P.Squeeze(0)(score)
    geo = P.Squeeze(0)(geo)
    boxes = get_boxes(score.asnumpy(), geo.asnumpy())
    return adjust_ratio(boxes, ratio_w, ratio_h)
|
||||
|
||||
|
||||
def plot_boxes(img, boxes):
    """plot boxes on image
    """
    if boxes is None:
        return img

    draw = ImageDraw.Draw(img)
    for box in boxes:
        draw.polygon([box[0], box[1], box[2], box[3], box[4],
                      box[5], box[6], box[7]], outline=(0, 255, 0))
    return img
|
||||
|
||||
|
||||
def detect_dataset(model, test_img_path, submit_path):
|
||||
"""detection on whole dataset, save .txt results in submit_path
|
||||
Input:
|
||||
model : detection model
|
||||
device : gpu if gpu is available
|
||||
test_img_path: dataset path
|
||||
submit_path : submit result for evaluation
|
||||
"""
|
||||
img_files = os.listdir(test_img_path)
|
||||
img_files = sorted([os.path.join(test_img_path, img_file)
|
||||
for img_file in img_files])
|
||||
|
||||
for i, img_file in enumerate(img_files):
|
||||
print('evaluating {} image'.format(i), end='\r')
|
||||
boxes = detect(Image.open(img_file), model)
|
||||
seq = []
|
||||
if boxes is not None:
|
||||
seq.extend([','.join([str(int(b))
|
||||
for b in box[:-1]]) + '\n' for box in boxes])
|
||||
with open(os.path.join(submit_path, 'res_' +
|
||||
os.path.basename(img_file).replace('.jpg', '.txt')), 'w') as f:
|
||||
f.writelines(seq)
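The output step of `detect_dataset` can be looked at in isolation. The following is a minimal, dependency-free sketch, assuming each box is a sequence of nine numbers (eight polygon coordinates followed by a score) as produced by `get_boxes`. The helper names `format_boxes` and `result_name` are hypothetical and are not part of this repository.

```python
import os


def format_boxes(boxes):
    """Turn (x1, y1, ..., x4, y4, score) rows into comma-separated text lines."""
    if boxes is None:
        return []
    # The trailing score is dropped; coordinates are truncated to int.
    return [','.join(str(int(b)) for b in box[:-1]) + '\n' for box in boxes]


def result_name(img_file):
    """Map '/some/dir/img_1.jpg' to 'res_img_1.txt'."""
    return 'res_' + os.path.basename(img_file).replace('.jpg', '.txt')


if __name__ == '__main__':
    demo = [[10.0, 20.0, 110.0, 20.0, 110.0, 60.0, 10.0, 60.0, 0.95]]
    print(format_boxes(demo))
    print(result_name('/data/icdar2015/Test/image/img_1.jpg'))
```

Note that the name mapping only rewrites a `.jpg` suffix, so images with another extension would keep their original extension in the result file name.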
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from detect import detect_dataset
|
||||
from src.east import EAST
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore icdar eval')
|
||||
|
||||
# device related
|
||||
parser.add_argument(
|
||||
'--device_target',
|
||||
type=str,
|
||||
default='Ascend',
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument(
|
||||
'--device_num',
|
||||
type=int,
|
||||
default=5,
|
||||
help='id of the device used for evaluation. (Default: 5)')
|
||||
|
||||
parser.add_argument(
|
||||
'--test_img_path',
|
||||
default='/data/icdar2015/Test/image/',
|
||||
type=str,
|
||||
help='Test image directory.')
|
||||
parser.add_argument('--checkpoint_path', default='best.ckpt', type=str,
|
||||
help='The ckpt file of EAST. Default: "best.ckpt".')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target,
|
||||
save_graphs=False,
|
||||
device_id=args.device_num)
|
||||
|
||||
|
||||
def eval_model(name, img_path, submit, save_flag=True):
|
||||
if os.path.exists(submit):
|
||||
shutil.rmtree(submit)
|
||||
os.mkdir(submit)
|
||||
network = EAST()
|
||||
param_dict = load_checkpoint(name)
|
||||
load_param_into_net(network, param_dict)
|
||||
network.set_train(False)
|
||||
|
||||
start_time = time.time()
|
||||
detect_dataset(network, img_path, submit)
|
||||
os.chdir(submit)
|
||||
res = subprocess.getoutput('zip -q submit.zip *.txt')
|
||||
res = subprocess.getoutput('mv submit.zip ../')
|
||||
os.chdir('../')
|
||||
res = subprocess.getoutput(
|
||||
'python ./evaluate/script.py -g=./evaluate/gt.zip -s=./submit.zip')
|
||||
print(res)
|
||||
os.remove('./submit.zip')
|
||||
print('eval time is {}'.format(time.time() - start_time))
|
||||
|
||||
if not save_flag:
|
||||
shutil.rmtree(submit)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_name = args.checkpoint_path
|
||||
test_img_path = args.test_img_path
|
||||
submit_path = './submit'
|
||||
eval_model(model_name, test_img_path, submit_path)
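`eval_model` packages the result files by shelling out to `zip` and `mv`. As a hedged alternative sketch (not what the script does), the same flat archive could be built with Python's standard `zipfile` module, which avoids changing the working directory. The function name `pack_results` is hypothetical.

```python
import os
import tempfile
import zipfile


def pack_results(submit_dir, zip_path):
    """Write every .txt file in submit_dir into a flat zip archive."""
    names = sorted(n for n in os.listdir(submit_dir) if n.endswith('.txt'))
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for name in names:
            # arcname keeps the archive flat, like running zip inside the directory.
            zf.write(os.path.join(submit_dir, name), arcname=name)
    return names


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        submit = os.path.join(tmp, 'submit')
        os.mkdir(submit)
        with open(os.path.join(submit, 'res_img_1.txt'), 'w') as f:
            f.write('10,20,110,20,110,60,10,60\n')
        print(pack_results(submit, os.path.join(tmp, 'submit.zip')))
```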
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.east import EAST
|
||||
|
||||
parser = argparse.ArgumentParser(description='EAST')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument('--image_height', type=int, default=512,
|
||||
help='image_height.')
|
||||
parser.add_argument('--image_width', type=int, default=512,
|
||||
help='image_width.')
|
||||
parser.add_argument(
|
||||
'--device_target',
|
||||
type=str,
|
||||
default="Ascend",
|
||||
choices=[
|
||||
'Ascend',
|
||||
'GPU',
|
||||
'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument(
|
||||
"--ckpt_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Checkpoint file path.")
|
||||
parser.add_argument(
|
||||
"--file_name",
|
||||
type=str,
|
||||
default="east",
|
||||
help="output file name.")
|
||||
parser.add_argument(
|
||||
"--file_format",
|
||||
type=str,
|
||||
choices=[
|
||||
"AIR",
|
||||
"ONNX",
|
||||
"MINDIR"],
|
||||
default="AIR",
|
||||
help="file format")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = EAST()
|
||||
|
||||
param_dict = load_checkpoint(args_opt.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros(
|
||||
[args_opt.batch_size, 3, args_opt.image_height, args_opt.image_width]), ms.float32)
|
||||
export(
|
||||
net,
|
||||
input_arr,
|
||||
file_name=args_opt.file_name,
|
||||
file_format=args_opt.file_format)
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import shutil
|
||||
import math
|
||||
import subprocess
|
||||
from PIL import ImageDraw
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops as P
|
||||
|
||||
import lanms
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="east inference")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_rotate_mat(theta):
|
||||
'''positive theta value means rotate clockwise'''
|
||||
return np.array([[math.cos(theta), -math.sin(theta)],
|
||||
[math.sin(theta), math.cos(theta)]])
|
||||
|
||||
|
||||
def is_valid_poly(res, score_shape, scale):
|
||||
"""check if the poly in image scope
|
||||
Input:
|
||||
res : restored poly in original image
|
||||
score_shape: score map shape
|
||||
scale : feature map -> image
|
||||
Output:
|
||||
True if valid
|
||||
"""
|
||||
cnt = 0
|
||||
for i in range(res.shape[1]):
|
||||
if res[0, i] < 0 or res[0, i] >= score_shape[1] * scale or \
|
||||
res[1, i] < 0 or res[1, i] >= score_shape[0] * scale:
|
||||
cnt += 1
|
||||
return cnt <= 1
|
||||
|
||||
|
||||
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
|
||||
"""restore polys from feature maps in given positions
|
||||
Input:
|
||||
valid_pos : potential text positions <numpy.ndarray, (n,2)>
|
||||
valid_geo : geometry in valid_pos <numpy.ndarray, (5,n)>
|
||||
score_shape: shape of score map
|
||||
scale : image / feature map
|
||||
Output:
|
||||
restored polys <numpy.ndarray, (n,8)>, index
|
||||
"""
|
||||
polys = []
|
||||
index = []
|
||||
valid_pos *= scale
|
||||
d = valid_geo[:4, :] # 4 x N
|
||||
angle = valid_geo[4, :] # N,
|
||||
|
||||
for i in range(valid_pos.shape[0]):
|
||||
x = valid_pos[i, 0]
|
||||
y = valid_pos[i, 1]
|
||||
y_min = y - d[0, i]
|
||||
y_max = y + d[1, i]
|
||||
x_min = x - d[2, i]
|
||||
x_max = x + d[3, i]
|
||||
rotate_mat = get_rotate_mat(-angle[i])
|
||||
|
||||
temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
|
||||
temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
|
||||
coordinates = np.concatenate((temp_x, temp_y), axis=0)
|
||||
res = np.dot(rotate_mat, coordinates)
|
||||
res[0, :] += x
|
||||
res[1, :] += y
|
||||
|
||||
if is_valid_poly(res, score_shape, scale):
|
||||
index.append(i)
|
||||
polys.append([res[0, 0], res[1, 0], res[0, 1], res[1, 1],
|
||||
res[0, 2], res[1, 2], res[0, 3], res[1, 3]])
|
||||
return np.array(polys), index
|
||||
|
||||
|
||||
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
|
||||
"""get boxes from feature map
|
||||
Input:
|
||||
score : score map from model <numpy.ndarray, (1,row,col)>
|
||||
geo : geo map from model <numpy.ndarray, (5,row,col)>
|
||||
score_thresh: threshold to segment score map
|
||||
nms_thresh : threshold in nms
|
||||
Output:
|
||||
boxes : final polys <numpy.ndarray, (n,9)>
|
||||
"""
|
||||
score = score[0, :, :]
|
||||
xy_text = np.argwhere(score > score_thresh) # n x 2, format is [r, c]
|
||||
if xy_text.size == 0:
|
||||
return None
|
||||
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 0])]
|
||||
valid_pos = xy_text[:, ::-1].copy() # n x 2, [x, y]
|
||||
valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]] # 5 x n
|
||||
polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape)
|
||||
if polys_restored.size == 0:
|
||||
return None
|
||||
|
||||
boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
|
||||
boxes[:, :8] = polys_restored
|
||||
boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]]
|
||||
boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
|
||||
return boxes
|
||||
|
||||
|
||||
def adjust_ratio(boxes, ratio_w, ratio_h):
|
||||
"""refine boxes
|
||||
Input:
|
||||
boxes : detected polys <numpy.ndarray, (n,9)>
|
||||
ratio_w: ratio of width
|
||||
ratio_h: ratio of height
|
||||
Output:
|
||||
refined boxes
|
||||
"""
|
||||
if boxes is None or boxes.size == 0:
|
||||
return None
|
||||
boxes[:, [0, 2, 4, 6]] /= ratio_w
|
||||
boxes[:, [1, 3, 5, 7]] /= ratio_h
|
||||
return np.around(boxes)
|
||||
|
||||
|
||||
def detect(img):
|
||||
"""detect text regions of img using model
|
||||
Input:
|
||||
img : PIL Image
|
||||
model : detection model
|
||||
device: gpu if gpu is available
|
||||
Output:
|
||||
detected polys
|
||||
"""
|
||||
img, ratio_h, ratio_w = resize_img(img)
|
||||
score, geo = model(load_pil(img))
|
||||
score = P.Squeeze(0)(score)
|
||||
geo = P.Squeeze(0)(geo)
|
||||
boxes = get_boxes(score.asnumpy(), geo.asnumpy())
|
||||
return adjust_ratio(boxes, ratio_w, ratio_h)
|
||||
|
||||
|
||||
def plot_boxes(img, boxes):
|
||||
"""plot boxes on image
|
||||
"""
|
||||
if boxes is None:
|
||||
return img
|
||||
|
||||
draw = ImageDraw.Draw(img)
|
||||
for box in boxes:
|
||||
draw.polygon([box[0], box[1], box[2], box[3], box[4],
|
||||
box[5], box[6], box[7]], outline=(0, 255, 0))
|
||||
return img
|
||||
|
||||
def detect_dataset(result_path, submit_path):
|
||||
"""detection on whole dataset, save .txt results in submit_path
|
||||
Input:
|
||||
result_path: directory of inference outputs (score and geo .bin files)
|
||||
submit_path: directory where the result .txt files are written
|
||||
"""
|
||||
img_files = os.listdir(result_path)
|
||||
img_files = sorted([os.path.join(result_path, img_file)
|
||||
for img_file in img_files])
|
||||
n = len(img_files)
|
||||
|
||||
for i in range(0, n, 2):
|
||||
print('evaluating {} image'.format(i // 2), end='\r')
|
||||
|
||||
score = np.fromfile(img_files[i], dtype=np.float32).reshape(1, 176, 320)
|
||||
geo = np.fromfile(img_files[i+1], dtype=np.float32).reshape(5, 176, 320)
|
||||
|
||||
boxes = get_boxes(score, geo)
|
||||
boxes = adjust_ratio(boxes, 1, 0.97777777777)
|
||||
seq = []
|
||||
if boxes is not None:
|
||||
seq.extend([','.join([str(int(b))
|
||||
for b in box[:-1]]) + '\n' for box in boxes])
|
||||
with open(os.path.join(submit_path, 'res_' +
|
||||
os.path.basename(img_files[i]).replace('_0.bin', '.txt')), 'w') as f:
|
||||
f.writelines(seq)
|
||||
|
||||
def eval_model(img_path, submit, save_flag=True):
|
||||
if os.path.exists(submit):
|
||||
shutil.rmtree(submit)
|
||||
os.mkdir(submit)
|
||||
|
||||
start_time = time.time()
|
||||
detect_dataset(img_path, submit)
|
||||
os.chdir(submit)
|
||||
res = subprocess.getoutput('zip -q submit.zip *.txt')
|
||||
res = subprocess.getoutput('mv submit.zip ../')
|
||||
os.chdir('../')
|
||||
res = subprocess.getoutput(
|
||||
'python ./evaluate/script.py -g=./evaluate/gt.zip -s=./submit.zip')
|
||||
print(res)
|
||||
os.remove('./submit.zip')
|
||||
print('eval time is {}'.format(time.time() - start_time))
|
||||
|
||||
if not save_flag:
|
||||
shutil.rmtree(submit)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
eval_model(args.result_path, './submit')
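Two pieces of arithmetic in this file are easier to follow on a single point. The sketch below restores one rotated box the way `restore_polys` does (anchor, four edge distances, angle), and shows where a height ratio close to `0.97777777777` could come from. Both helper names are hypothetical, and the 1280x720 image size is an assumption about the test images, not something the script states.

```python
import math


def restore_quad(x, y, d, angle):
    """Restore the four corners of a rotated box.

    x, y  : anchor pixel in image coordinates
    d     : distances (top, bottom, left, right) from the anchor to the box edges
    angle : predicted rotation angle in radians
    """
    top, bottom, left, right = d
    # Corner offsets relative to the anchor, before rotation.
    offsets = [(-left, -top), (right, -top), (right, bottom), (-left, bottom)]
    # Same matrix as get_rotate_mat(-angle).
    c, s = math.cos(-angle), math.sin(-angle)
    return [(c * dx - s * dy + x, s * dx + c * dy + y) for dx, dy in offsets]


def resize_ratios(map_h, map_w, img_h, img_w, scale=4):
    """(ratio_h, ratio_w) between the network input size and the image size."""
    return map_h * scale / img_h, map_w * scale / img_w


if __name__ == '__main__':
    print(restore_quad(100, 50, (10, 20, 30, 40), 0.0))
    # A 176 x 320 score map at scale 4 corresponds to a 704 x 1280 input.
    print(resize_ratios(176, 320, 720, 1280))
```

With these assumed sizes, the width ratio is 1 and the height ratio is 704 / 720, which matches the constants passed to `adjust_ratio` in `detect_dataset`.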
|
|
@ -0,0 +1,81 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||
RANK_TABLE_FILE=$(get_real_path $3)
|
||||
echo $DATASET_PATH
|
||||
echo $PRETRAINED_BACKBONE
|
||||
echo $RANK_TABLE_FILE
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $RANK_TABLE_FILE ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_FILE
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_FILE
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--data_dir=$DATASET_PATH \
|
||||
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||
--is_distributed=1 \
|
||||
--lr=0.001 \
|
||||
--max_epoch=600 \
|
||||
--per_batch_size=8 \
|
||||
--lr_scheduler=my_lr > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
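The `get_real_path` helper used by these launch scripts returns absolute paths unchanged and resolves relative ones against the current directory. A rough Python equivalent is sketched below, assuming POSIX-style paths. Unlike `realpath -m`, `normpath` does not resolve symbolic links, so this is an approximation rather than an exact port.

```python
import posixpath


def get_real_path(path, cwd):
    """Return path unchanged if absolute, else resolve it against cwd."""
    if path.startswith('/'):
        return path
    # normpath collapses '.' and '..' but, unlike realpath, ignores symlinks.
    return posixpath.normpath(posixpath.join(cwd, path))


if __name__ == '__main__':
    print(get_real_path('/data/icdar2015', '/work/scripts'))
    print(get_real_path('../data/./icdar2015', '/work/scripts'))
```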
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CKPT_PATH] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
CKPT_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
echo $CKPT_PATH
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $CKPT_PATH ]
|
||||
then
|
||||
echo "error: CKPT_PATH=$CKPT_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=$3
|
||||
|
||||
export RANK_ID=$3
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf ./eval_standalone
|
||||
mkdir ./eval_standalone
|
||||
cp ../*.py ./eval_standalone
|
||||
cp -r ../src ./eval_standalone
|
||||
cp -r ../evaluate ./eval_standalone
|
||||
cp -r ../lanms ./eval_standalone
|
||||
cd ./eval_standalone || exit
|
||||
echo "start evaluation for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py \
|
||||
--test_img_path=$DATASET_PATH \
|
||||
--checkpoint_path=$CKPT_PATH \
|
||||
--device_num=$DEVICE_ID > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,100 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 2 || $# -gt 3 ]]; then
|
||||
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
|
||||
device_id=0
|
||||
if [ $# == 3 ]; then
|
||||
device_id=$3
|
||||
fi
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "device id: "$device_id
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
cd ../ascend310_infer/ || exit
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
sh build.sh &> build.log
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
cd - || exit
|
||||
if [ -d result_Files ]; then
|
||||
rm -rf ./result_Files
|
||||
fi
|
||||
if [ -d time_Result ]; then
|
||||
rm -rf ./time_Result
|
||||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
../ascend310_infer/out/main --mindir_path=$model --input0_path=$data_path --device_id=$device_id &> infer.log
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
{
|
||||
cd .. || exit
|
||||
python3.7 postprocess.py --result_path=./scripts/result_Files &> ./acc.log &
|
||||
}
|
||||
|
||||
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo " execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
cal_acc
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "calculate accuracy failed"
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||
DEVICE_ID=$3
|
||||
echo $DATASET_PATH
|
||||
echo $PRETRAINED_BACKBONE
|
||||
echo $DEVICE_ID
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=$3
|
||||
|
||||
export RANK_ID=$3
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf ./train_standalone
|
||||
mkdir ./train_standalone
|
||||
cp ../*.py ./train_standalone
|
||||
cp -r ../src ./train_standalone
|
||||
cd ./train_standalone || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--data_dir=$DATASET_PATH \
|
||||
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--is_distributed=0 \
|
||||
--lr=0.001 \
|
||||
--max_epoch=600 \
|
||||
--per_batch_size=24 \
|
||||
--lr_scheduler=my_lr > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,476 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import math
|
||||
import os
|
||||
|
||||
from shapely.geometry import Polygon
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
from src.distributed_sampler import DistributedSampler
|
||||
|
||||
|
||||
def cal_distance(x1, y1, x2, y2):
|
||||
'''calculate the Euclidean distance'''
|
||||
return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
|
||||
|
||||
|
||||
def move_points(vertices, index1, index2, r, coef):
|
||||
'''move the two points to shrink edge
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
index1 : offset of point1
|
||||
index2 : offset of point2
|
||||
r : [r1, r2, r3, r4] in paper
|
||||
coef : shrink ratio in paper
|
||||
Output:
|
||||
vertices: vertices where one edge has been shrunk
|
||||
'''
|
||||
index1 = index1 % 4
|
||||
index2 = index2 % 4
|
||||
x1_index = index1 * 2 + 0
|
||||
y1_index = index1 * 2 + 1
|
||||
x2_index = index2 * 2 + 0
|
||||
y2_index = index2 * 2 + 1
|
||||
|
||||
r1 = r[index1]
|
||||
r2 = r[index2]
|
||||
length_x = vertices[x1_index] - vertices[x2_index]
|
||||
length_y = vertices[y1_index] - vertices[y2_index]
|
||||
length = cal_distance(
|
||||
vertices[x1_index],
|
||||
vertices[y1_index],
|
||||
vertices[x2_index],
|
||||
vertices[y2_index])
|
||||
if length > 1:
|
||||
ratio = (r1 * coef) / length
|
||||
vertices[x1_index] += ratio * (-length_x)
|
||||
vertices[y1_index] += ratio * (-length_y)
|
||||
ratio = (r2 * coef) / length
|
||||
vertices[x2_index] += ratio * length_x
|
||||
vertices[y2_index] += ratio * length_y
|
||||
return vertices
|
||||
|
||||
|
||||
def shrink_poly(vertices, coef=0.3):
|
||||
'''shrink the text region
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
coef : shrink ratio in paper
|
||||
Output:
|
||||
v : vertices of shrunk text region <numpy.ndarray, (8,)>
|
||||
'''
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
|
||||
r1 = min(cal_distance(x1, y1, x2, y2), cal_distance(x1, y1, x4, y4))
|
||||
r2 = min(cal_distance(x2, y2, x1, y1), cal_distance(x2, y2, x3, y3))
|
||||
r3 = min(cal_distance(x3, y3, x2, y2), cal_distance(x3, y3, x4, y4))
|
||||
r4 = min(cal_distance(x4, y4, x1, y1), cal_distance(x4, y4, x3, y3))
|
||||
r = [r1, r2, r3, r4]
|
||||
|
||||
# obtain offset to perform move_points() automatically
|
||||
if cal_distance(x1, y1, x2, y2) + cal_distance(x3, y3, x4, y4) > \
|
||||
cal_distance(x2, y2, x3, y3) + cal_distance(x1, y1, x4, y4):
|
||||
offset = 0 # two longer edges are (x1y1-x2y2) & (x3y3-x4y4)
|
||||
else:
|
||||
offset = 1 # two longer edges are (x2y2-x3y3) & (x4y4-x1y1)
|
||||
|
||||
v = vertices.copy()
|
||||
v = move_points(v, 0 + offset, 1 + offset, r, coef)
|
||||
v = move_points(v, 2 + offset, 3 + offset, r, coef)
|
||||
v = move_points(v, 1 + offset, 2 + offset, r, coef)
|
||||
v = move_points(v, 3 + offset, 4 + offset, r, coef)
|
||||
return v
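The two quantities `shrink_poly` computes before moving any point can be checked on a plain rectangle. The sketch below is a standalone illustration using only the standard library. The names `reference_lengths` and `longer_pair_offset` are hypothetical and mirror, rather than replace, the logic above.

```python
import math


def _dist(p, q):
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def _points(vertices):
    """Split a flat (x1, y1, ..., x4, y4) sequence into four (x, y) pairs."""
    return [(vertices[2 * i], vertices[2 * i + 1]) for i in range(4)]


def reference_lengths(vertices):
    """Return [r1, r2, r3, r4]: for each corner, the shorter adjacent edge."""
    pts = _points(vertices)
    return [min(_dist(pts[i], pts[(i - 1) % 4]), _dist(pts[i], pts[(i + 1) % 4]))
            for i in range(4)]


def longer_pair_offset(vertices):
    """0 if edges 1-2 and 3-4 are the longer pair, else 1 (as in shrink_poly)."""
    pts = _points(vertices)
    edges = [_dist(pts[i], pts[(i + 1) % 4]) for i in range(4)]
    return 0 if edges[0] + edges[2] > edges[1] + edges[3] else 1


if __name__ == '__main__':
    wide = [0, 0, 4, 0, 4, 2, 0, 2]
    print(reference_lengths(wide), longer_pair_offset(wide))
```

For a 4 x 2 rectangle every corner's reference length is the short side, and the offset selects the horizontal edges as the longer pair, so those are the first edges passed to `move_points`.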
|
||||
|
||||
|
||||
def get_rotate_mat(theta):
|
||||
'''positive theta value means rotate clockwise'''
|
||||
return np.array([[math.cos(theta), -math.sin(theta)],
|
||||
[math.sin(theta), math.cos(theta)]])
|
||||
|
||||
|
||||
def rotate_vertices(vertices, theta, anchor=None):
|
||||
'''rotate vertices around anchor
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
theta : angle in radian measure
|
||||
anchor : fixed position during rotation
|
||||
Output:
|
||||
rotated vertices <numpy.ndarray, (8,)>
|
||||
'''
|
||||
v = vertices.reshape((4, 2)).T
|
||||
if anchor is None:
|
||||
anchor = v[:, :1]
|
||||
rotate_mat = get_rotate_mat(theta)
|
||||
res = np.dot(rotate_mat, v - anchor)
|
||||
return (res + anchor).T.reshape(-1)
|
||||
|
||||
|
||||
def get_boundary(vertices):
|
||||
'''get the tight boundary around given vertices
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
Output:
|
||||
the boundary
|
||||
'''
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
|
||||
x_min = min(x1, x2, x3, x4)
|
||||
x_max = max(x1, x2, x3, x4)
|
||||
y_min = min(y1, y2, y3, y4)
|
||||
y_max = max(y1, y2, y3, y4)
|
||||
return x_min, x_max, y_min, y_max
|
||||
|
||||
|
||||
def cal_error(vertices):
|
||||
'''default orientation is x1y1 : left-top, x2y2 : right-top, x3y3 : right-bot, x4y4 : left-bot
|
||||
calculate the difference between the vertices orientation and default orientation
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
Output:
|
||||
err : difference measure
|
||||
'''
|
||||
x_min, x_max, y_min, y_max = get_boundary(vertices)
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = vertices
|
||||
err = cal_distance(x1, y1, x_min, y_min) + \
|
||||
cal_distance(x2, y2, x_max, y_min) + \
|
||||
cal_distance(x3, y3, x_max, y_max) + \
|
||||
cal_distance(x4, y4, x_min, y_max)
|
||||
return err
|
||||
|
||||
|
||||
def find_min_rect_angle(vertices):
|
||||
'''find the best angle to rotate poly and obtain min rectangle
|
||||
Input:
|
||||
vertices: vertices of text region <numpy.ndarray, (8,)>
|
||||
Output:
|
||||
the best angle <radian measure>
|
||||
'''
|
||||
angle_interval = 1
|
||||
angle_list = list(range(-90, 90, angle_interval))
|
||||
area_list = []
|
||||
for theta in angle_list:
|
||||
rotated = rotate_vertices(vertices, theta / 180 * math.pi)
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = rotated
|
||||
temp_area = (max(x1, x2, x3, x4) - min(x1, x2, x3, x4)) * \
|
||||
(max(y1, y2, y3, y4) - min(y1, y2, y3, y4))
|
||||
area_list.append(temp_area)
|
||||
|
||||
sorted_area_index = sorted(
|
||||
list(
|
||||
range(
|
||||
len(area_list))),
|
||||
key=lambda k: area_list[k])
|
||||
min_error = float('inf')
|
||||
best_index = -1
|
||||
rank_num = 10
|
||||
# find the best angle with correct orientation
|
||||
for index in sorted_area_index[:rank_num]:
|
||||
rotated = rotate_vertices(vertices, angle_list[index] / 180 * math.pi)
|
||||
temp_error = cal_error(rotated)
|
||||
if temp_error < min_error:
|
||||
min_error = temp_error
|
||||
best_index = index
|
||||
return angle_list[best_index] / 180 * math.pi
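The first stage of `find_min_rect_angle`, ranking candidate angles by the area of the axis-aligned bounding box after rotation, can be sketched without NumPy. This is an illustrative reimplementation under the same conventions (rotation about the first vertex, one-degree steps in [-90, 90)). The function names are hypothetical, and the second stage that picks the candidate with the smallest `cal_error` is not reproduced here.

```python
import math


def bbox_area_after_rotation(vertices, theta):
    """Bounding-box area after rotating the quad by theta about its first vertex."""
    ax, ay = vertices[0], vertices[1]
    c, s = math.cos(theta), math.sin(theta)
    xs, ys = [], []
    for i in range(4):
        dx, dy = vertices[2 * i] - ax, vertices[2 * i + 1] - ay
        xs.append(c * dx - s * dy + ax)
        ys.append(s * dx + c * dy + ay)
    return (max(xs) - min(xs)) * (max(ys) - min(ys))


def smallest_area_angles(vertices, top_k=10):
    """Angles in degrees from [-90, 90), sorted by bounding-box area, smallest first."""
    angles = list(range(-90, 90))
    angles.sort(key=lambda a: bbox_area_after_rotation(vertices, math.radians(a)))
    return angles[:top_k]


if __name__ == '__main__':
    rect = [0, 0, 4, 0, 4, 2, 0, 2]
    print(bbox_area_after_rotation(rect, 0.0))
    print(smallest_area_angles(rect, top_k=2))
```

For an axis-aligned rectangle, 0 and -90 degrees give almost the same minimal area, which is why the original function needs the orientation error to choose between such candidates.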
|
||||
|
||||
|
||||
def is_cross_text(start_loc, length, vertices):
|
||||
'''check if the crop image crosses text regions
|
||||
Input:
|
||||
start_loc: left-top position
|
||||
length : length of crop image
|
||||
vertices : vertices of text regions <numpy.ndarray, (n,8)>
|
||||
Output:
|
||||
True if crop image crosses text region
|
||||
'''
|
||||
if vertices.size == 0:
|
||||
return False
|
||||
start_w, start_h = start_loc
|
||||
a = np.array([start_w, start_h, start_w +
|
||||
length, start_h, start_w +
|
||||
length, start_h +
|
||||
length, start_w, start_h +
|
||||
length]).reshape((4, 2))
|
||||
p1 = Polygon(a).convex_hull
|
||||
for vertice in vertices:
|
||||
p2 = Polygon(vertice.reshape((4, 2))).convex_hull
|
||||
inter = p1.intersection(p2).area
|
||||
if 0.01 <= inter / p2.area <= 0.99:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def crop_img(img, vertices, labels, length):
|
||||
'''crop img patches to obtain batch and augment
|
||||
Input:
|
||||
img : PIL Image
|
||||
vertices : vertices of text regions <numpy.ndarray, (n,8)>
|
||||
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
|
||||
length : length of cropped image region
|
||||
Output:
|
||||
region : cropped image region
|
||||
new_vertices: new vertices in cropped region
|
||||
'''
|
||||
h, w = img.height, img.width
|
||||
# confirm the shortest side of image >= length
|
||||
if h >= w and w < length:
|
||||
img = img.resize((length, int(h * length / w)), Image.BILINEAR)
|
||||
elif h < w and h < length:
|
||||
img = img.resize((int(w * length / h), length), Image.BILINEAR)
|
||||
ratio_w = img.width / w
|
||||
ratio_h = img.height / h
|
||||
assert (ratio_w >= 1 and ratio_h >= 1)
|
||||
|
||||
new_vertices = np.zeros(vertices.shape)
|
||||
if vertices.size > 0:
|
||||
new_vertices[:, [0, 2, 4, 6]] = vertices[:, [0, 2, 4, 6]] * ratio_w
|
||||
new_vertices[:, [1, 3, 5, 7]] = vertices[:, [1, 3, 5, 7]] * ratio_h
|
||||
|
||||
# find random position
|
||||
remain_h = img.height - length
|
||||
remain_w = img.width - length
|
||||
flag = True
|
||||
cnt = 0
|
||||
while flag and cnt < 1000:
|
||||
cnt += 1
|
||||
start_w = int(np.random.rand() * remain_w)
|
||||
start_h = int(np.random.rand() * remain_h)
|
||||
flag = is_cross_text([start_w, start_h], length,
|
||||
new_vertices[labels == 1, :])
|
||||
box = (start_w, start_h, start_w + length, start_h + length)
|
||||
region = img.crop(box)
|
||||
if new_vertices.size == 0:
|
||||
return region, new_vertices
|
||||
|
||||
new_vertices[:, [0, 2, 4, 6]] -= start_w
|
||||
new_vertices[:, [1, 3, 5, 7]] -= start_h
|
||||
return region, new_vertices
|
||||
|
||||
|
||||
def rotate_all_pixels(rotate_mat, anchor_x, anchor_y, length):
|
||||
'''get rotated locations of all pixels for next stages
|
||||
Input:
|
||||
rotate_mat: rotation matrix
|
||||
anchor_x : fixed x position
|
||||
anchor_y : fixed y position
|
||||
length : length of image
|
||||
Output:
|
||||
rotated_x : rotated x positions <numpy.ndarray, (length,length)>
|
||||
rotated_y : rotated y positions <numpy.ndarray, (length,length)>
|
||||
'''
|
||||
x = np.arange(length)
|
||||
y = np.arange(length)
|
||||
x, y = np.meshgrid(x, y)
|
||||
x_lin = x.reshape((1, x.size))
|
||||
y_lin = y.reshape((1, x.size))
|
||||
coord_mat = np.concatenate((x_lin, y_lin), 0)
|
||||
rotated_coord = np.matmul(rotate_mat.astype(np.float16),
|
||||
(coord_mat - np.array([[anchor_x],
|
||||
[anchor_y]])).astype(np.float16)) + np.array([[anchor_x],
|
||||
[anchor_y]])
|
||||
rotated_x = rotated_coord[0, :].reshape(x.shape)
|
||||
rotated_y = rotated_coord[1, :].reshape(y.shape)
|
||||
return rotated_x, rotated_y
|
||||
|
||||
|
||||
def adjust_height(img, vertices, ratio=0.2):
|
||||
'''adjust height of image to aug data
|
||||
Input:
|
||||
img : PIL Image
|
||||
vertices : vertices of text regions <numpy.ndarray, (n,8)>
|
||||
ratio : height changes in [0.8, 1.2]
|
||||
Output:
|
||||
img : adjusted PIL Image
|
||||
new_vertices: adjusted vertices
|
||||
'''
|
||||
ratio_h = 1 + ratio * (np.random.rand() * 2 - 1)
|
||||
old_h = img.height
|
||||
new_h = int(np.around(old_h * ratio_h))
|
||||
img = img.resize((img.width, new_h), Image.BILINEAR)
|
||||
|
||||
new_vertices = vertices.copy()
|
||||
if vertices.size > 0:
|
||||
new_vertices[:, [1, 3, 5, 7]] = vertices[:, [1, 3, 5, 7]] * (new_h / old_h)
|
||||
return img, new_vertices
|
||||
|
||||
|
||||
def rotate_img(img, vertices, angle_range=10):
|
||||
'''rotate image [-10, 10] degree to aug data
|
||||
Input:
|
||||
img : PIL Image
|
||||
vertices : vertices of text regions <numpy.ndarray, (n,8)>
|
||||
angle_range : rotate range
|
||||
Output:
|
||||
img : rotated PIL Image
|
||||
new_vertices: rotated vertices
|
||||
'''
|
||||
center_x = (img.width - 1) / 2
|
||||
center_y = (img.height - 1) / 2
|
||||
angle = angle_range * (np.random.rand() * 2 - 1)
|
||||
img = img.rotate(angle, Image.BILINEAR)
|
||||
new_vertices = np.zeros(vertices.shape)
|
||||
for i, vertice in enumerate(vertices):
|
||||
new_vertices[i, :] = rotate_vertices(
|
||||
vertice, -angle / 180 * math.pi, np.array([[center_x], [center_y]]))
|
||||
return img, new_vertices
|
||||
|
||||
|
||||
def get_score_geo(img, vertices, labels, scale, length):
|
||||
'''generate score gt and geometry gt
|
||||
Input:
|
||||
img : PIL Image
|
||||
vertices: vertices of text regions <numpy.ndarray, (n,8)>
|
||||
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
|
||||
scale : feature map / image
|
||||
length : image length
|
||||
Output:
|
||||
score gt, geo gt, ignored
|
||||
'''
|
||||
score_map = np.zeros(
|
||||
(int(img.height * scale), int(img.width * scale), 1), np.float32)
|
||||
geo_map = np.zeros(
|
||||
(int(img.height * scale), int(img.width * scale), 5), np.float32)
|
||||
ignored_map = np.zeros(
|
||||
(int(img.height * scale), int(img.width * scale), 1), np.float32)
|
||||
|
||||
index = np.arange(0, length, int(1 / scale))
|
||||
index_x, index_y = np.meshgrid(index, index)
|
||||
ignored_polys = []
|
||||
polys = []
|
||||
|
||||
for i, vertice in enumerate(vertices):
|
||||
if labels[i] == 0:
|
||||
ignored_polys.append(np.around(scale * vertice.reshape((4, 2))).astype(np.int32))
|
||||
continue
|
||||
|
||||
poly = np.around(scale * shrink_poly(vertice).reshape((4, 2))).astype(np.int32)
|
||||
polys.append(poly)
|
||||
temp_mask = np.zeros(score_map.shape[:-1], np.float32)
|
||||
cv2.fillPoly(temp_mask, [poly], 1)
|
||||
|
||||
theta = find_min_rect_angle(vertice)
|
||||
rotate_mat = get_rotate_mat(theta)
|
||||
|
||||
rotated_vertices = rotate_vertices(vertice, theta)
|
||||
x_min, x_max, y_min, y_max = get_boundary(rotated_vertices)
|
||||
rotated_x, rotated_y = rotate_all_pixels(rotate_mat, vertice[0], vertice[1], length)
|
||||
|
||||
d1 = rotated_y - y_min
|
||||
d1[d1 < 0] = 0
|
||||
d2 = y_max - rotated_y
|
||||
d2[d2 < 0] = 0
|
||||
d3 = rotated_x - x_min
|
||||
d3[d3 < 0] = 0
|
||||
d4 = x_max - rotated_x
|
||||
d4[d4 < 0] = 0
|
||||
geo_map[:, :, 0] += d1[index_y, index_x] * temp_mask
|
||||
geo_map[:, :, 1] += d2[index_y, index_x] * temp_mask
|
||||
geo_map[:, :, 2] += d3[index_y, index_x] * temp_mask
|
||||
geo_map[:, :, 3] += d4[index_y, index_x] * temp_mask
|
||||
geo_map[:, :, 4] += theta * temp_mask
|
||||
|
||||
cv2.fillPoly(ignored_map, ignored_polys, 1)
|
||||
cv2.fillPoly(score_map, polys, 1)
|
||||
return score_map, geo_map, ignored_map
|
||||
|
||||
|
||||
def extract_vertices(lines):
|
||||
'''extract vertices info from txt lines
|
||||
Input:
|
||||
lines : list of string info
|
||||
Output:
|
||||
vertices: vertices of text regions <numpy.ndarray, (n,8)>
|
||||
labels : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
|
||||
'''
|
||||
labels = []
|
||||
vertices = []
|
||||
for line in lines:
|
||||
vertices.append(list(map(int, line.rstrip('\n').lstrip('\ufeff').split(',')[:8])))
|
||||
label = 0 if '###' in line else 1
|
||||
labels.append(label)
|
||||
return np.array(vertices), np.array(labels)
|
||||
|
||||
|
||||
class ICDAREASTDataset:
|
||||
def __init__(self, img_path, gt_path, scale=0.25, length=512):
|
||||
super(ICDAREASTDataset, self).__init__()
|
||||
self.img_files = [os.path.join(
|
||||
img_path,
|
||||
img_file) for img_file in sorted(os.listdir(img_path))]
|
||||
self.gt_files = [
|
||||
os.path.join(
|
||||
gt_path,
|
||||
gt_file) for gt_file in sorted(
|
||||
os.listdir(gt_path))]
|
||||
self.scale = scale
|
||||
self.length = length
|
||||
|
||||
def __getitem__(self, index):
|
||||
with open(self.gt_files[index], 'r') as f:
|
||||
lines = f.readlines()
|
||||
vertices, labels = extract_vertices(lines)
|
||||
|
||||
img = Image.open(self.img_files[index])
|
||||
img, vertices = adjust_height(img, vertices)
|
||||
img, vertices = rotate_img(img, vertices)
|
||||
img, vertices = crop_img(img, vertices, labels, self.length)
|
||||
score_map, geo_map, ignored_map = get_score_geo(
|
||||
img, vertices, labels, self.scale, self.length)
|
||||
score_map = score_map.transpose(2, 0, 1)
|
||||
ignored_map = ignored_map.transpose(2, 0, 1)
|
||||
geo_map = geo_map.transpose(2, 0, 1)
|
||||
if np.sum(score_map) < 1:
|
||||
score_map[0, 0, 0] = 1
|
||||
return img, score_map, geo_map, ignored_map
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_files)
|
||||
|
||||
|
||||
def create_east_dataset(
|
||||
img_root,
|
||||
txt_root,
|
||||
batch_size,
|
||||
device_num,
|
||||
rank,
|
||||
is_training=True):
|
||||
east_data = ICDAREASTDataset(img_path=img_root, gt_path=txt_root)
|
||||
distributed_sampler = DistributedSampler(
|
||||
len(east_data), device_num, rank, shuffle=True)
|
||||
|
||||
trans_list = [CV.RandomColorAdjust(0.5, 0.5, 0.5, 0.25),
|
||||
CV.Rescale(1 / 255.0, 0),
|
||||
CV.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
CV.HWC2CHW()]
|
||||
if is_training:
|
||||
dataset_column_names = [
|
||||
"image",
|
||||
"score_map",
|
||||
"geo_map",
|
||||
"training_mask"]
|
||||
ds = de.GeneratorDataset(
|
||||
east_data,
|
||||
column_names=dataset_column_names,
|
||||
num_parallel_workers=32,
|
||||
sampler=distributed_sampler)
|
||||
ds = ds.map(
|
||||
operations=trans_list,
|
||||
input_columns=["image"],
|
||||
num_parallel_workers=8,
|
||||
python_multiprocessing=True)
|
||||
ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=True)
|
||||
|
||||
return ds, len(east_data)
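
The angle search in `find_min_rect_angle` can be illustrated in isolation. The sketch below is a simplified stand-in, not the code from this file: `rotate_quad`, `bounding_area`, and `min_area_angle` are hypothetical names, rotation is taken about the first vertex (an assumption about what `rotate_vertices` does by default), and the second pass that re-ranks the ten smallest areas with `cal_error` is left out.

```python
import math

import numpy as np


def rotate_quad(vertices, theta):
    """Rotate a quad (flat sequence of 8 numbers) about its first vertex by theta radians."""
    pts = np.asarray(vertices, dtype=np.float64).reshape(4, 2).T  # shape (2, 4)
    anchor = pts[:, :1]  # first vertex, kept fixed (assumption)
    rot = np.array([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta), math.cos(theta)]])
    return (rot @ (pts - anchor) + anchor).T.reshape(-1)


def bounding_area(vertices):
    """Area of the axis-aligned bounding box of a quad."""
    pts = np.asarray(vertices, dtype=np.float64).reshape(4, 2)
    width = pts[:, 0].max() - pts[:, 0].min()
    height = pts[:, 1].max() - pts[:, 1].min()
    return width * height


def min_area_angle(vertices, step_deg=1):
    """Return the angle (radians) in [-90, 90) degrees that minimizes the bounding-box area."""
    angles = list(range(-90, 90, step_deg))
    areas = [bounding_area(rotate_quad(vertices, math.radians(a))) for a in angles]
    return math.radians(angles[int(np.argmin(areas))])
```

Scanning a fixed grid of angles is coarse but cheap. Several angles can give the same minimal area (for a rectangle, angles 90 degrees apart do), which is why the original function adds an orientation check on the best candidates.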
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_size,
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=True):
|
||||
if num_replicas is None:
|
||||
print(
|
||||
"***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print(
|
||||
"***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(
|
||||
math.ceil(
|
||||
dataset_size *
|
||||
1.0 /
|
||||
self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(
|
||||
seed=self.epoch).permutation(
|
||||
self.dataset_size)
|
||||
# permutation of 0 .. dataset_size - 1 as an np.ndarray, used as
|
||||
# dataset indices; convert it to a plain list
|
||||
indices = indices.tolist()
|
||||
# advance the epoch so the next pass is seeded differently
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
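
The pad-and-stride logic in `__iter__` is easier to follow on a small case. The following is a minimal sketch of the non-shuffled path only; `shard_indices` is a hypothetical helper written for illustration, not part of this file.

```python
import math


def shard_indices(dataset_size, num_replicas, rank):
    """Return the sample indices one rank would see, without shuffling."""
    num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_size))
    # Pad by repeating the first few indices so every rank gets num_samples items.
    indices += indices[:total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


# 10 samples over 4 ranks: 3 samples per rank, 2 padded duplicates in total.
shards = [shard_indices(10, 4, r) for r in range(4)]
```

With 10 samples and 4 ranks, the padded list has 12 entries, so ranks 2 and 3 each receive one repeated sample (index 0 and index 1).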
|
|
@ -0,0 +1,363 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
|
||||
|
||||
def _conv(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
pad_mode='pad'):
|
||||
"""Conv2D wrapper."""
|
||||
weights = 'ones'
|
||||
layers = []
|
||||
layers += [nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
pad_mode=pad_mode,
|
||||
weight_init=weights,
|
||||
has_bias=False)]
|
||||
layers += [nn.BatchNorm2d(out_channels)]
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
class VGG16FeatureExtraction(nn.Cell):
|
||||
"""VGG16 feature extractor (backbone) for EAST."""
|
||||
|
||||
def __init__(self):
|
||||
super(VGG16FeatureExtraction, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv1_1 = _conv(
|
||||
in_channels=3,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv1_2 = _conv(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
self.conv2_1 = _conv(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv2_2 = _conv(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
self.conv3_1 = _conv(
|
||||
in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv3_2 = _conv(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv3_3 = _conv(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
self.conv4_1 = _conv(
|
||||
in_channels=256,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv4_2 = _conv(
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv4_3 = _conv(
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
self.conv5_1 = _conv(
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv5_2 = _conv(
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.conv5_3 = _conv(
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, out):
|
||||
""" Construction of VGG """
|
||||
f_0 = out
|
||||
out = self.cast(out, mstype.float32)
|
||||
out = self.conv1_1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv1_2(out)
|
||||
out = self.relu(out)
|
||||
out = self.max_pool(out)
|
||||
|
||||
out = self.conv2_1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2_2(out)
|
||||
out = self.relu(out)
|
||||
out = self.max_pool(out)
|
||||
f_2 = out
|
||||
|
||||
out = self.conv3_1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3_2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3_3(out)
|
||||
out = self.relu(out)
|
||||
out = self.max_pool(out)
|
||||
f_3 = out
|
||||
|
||||
out = self.conv4_1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv4_2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv4_3(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.max_pool(out)
|
||||
f_4 = out
|
||||
out = self.conv5_1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv5_2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv5_3(out)
|
||||
out = self.relu(out)
|
||||
out = self.max_pool(out)
|
||||
f_5 = out
|
||||
|
||||
return f_0, f_2, f_3, f_4, f_5
|
||||
|
||||
|
||||
class Merge(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Merge, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(1024, 128, 1, has_bias=True)
|
||||
self.bn1 = nn.BatchNorm2d(128)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(
|
||||
128,
|
||||
128,
|
||||
3,
|
||||
padding=1,
|
||||
pad_mode='pad',
|
||||
has_bias=True)
|
||||
self.bn2 = nn.BatchNorm2d(128)
|
||||
self.relu2 = nn.ReLU()
|
||||
|
||||
self.conv3 = nn.Conv2d(384, 64, 1, has_bias=True)
|
||||
self.bn3 = nn.BatchNorm2d(64)
|
||||
self.relu3 = nn.ReLU()
|
||||
self.conv4 = nn.Conv2d(
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
padding=1,
|
||||
pad_mode='pad',
|
||||
has_bias=True)
|
||||
self.bn4 = nn.BatchNorm2d(64)
|
||||
self.relu4 = nn.ReLU()
|
||||
|
||||
self.conv5 = nn.Conv2d(192, 32, 1)
|
||||
self.bn5 = nn.BatchNorm2d(32)
|
||||
self.relu5 = nn.ReLU()
|
||||
self.conv6 = nn.Conv2d(
|
||||
32,
|
||||
32,
|
||||
3,
|
||||
padding=1,
|
||||
pad_mode='pad',
|
||||
has_bias=True)
|
||||
self.bn6 = nn.BatchNorm2d(32)
|
||||
self.relu6 = nn.ReLU()
|
||||
|
||||
self.conv7 = nn.Conv2d(
|
||||
32,
|
||||
32,
|
||||
3,
|
||||
padding=1,
|
||||
pad_mode='pad',
|
||||
has_bias=True)
|
||||
self.bn7 = nn.BatchNorm2d(32)
|
||||
self.relu7 = nn.ReLU()
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, x, f1, f2, f3, f4):
|
||||
img_height = P.Shape()(x)[2]
|
||||
img_width = P.Shape()(x)[3]
|
||||
|
||||
out = P.ResizeBilinear((img_height // 16, img_width // 16), True)(f4)
|
||||
out = self.concat((out, f3))
|
||||
out = self.relu1(self.bn1(self.conv1(out)))
|
||||
out = self.relu2(self.bn2(self.conv2(out)))
|
||||
|
||||
out = P.ResizeBilinear((img_height // 8, img_width // 8), True)(out)
|
||||
out = self.concat((out, f2))
|
||||
out = self.relu3(self.bn3(self.conv3(out)))
|
||||
out = self.relu4(self.bn4(self.conv4(out)))
|
||||
|
||||
out = P.ResizeBilinear((img_height // 4, img_width // 4), True)(out)
|
||||
out = self.concat((out, f1))
|
||||
out = self.relu5(self.bn5(self.conv5(out)))
|
||||
out = self.relu6(self.bn6(self.conv6(out)))
|
||||
|
||||
out = self.relu7(self.bn7(self.conv7(out)))
|
||||
return out
|
||||
|
||||
|
||||
class Output(nn.Cell):
|
||||
def __init__(self, scope=512):
|
||||
super(Output, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 1, 1)
|
||||
self.sigmoid1 = nn.Sigmoid()
|
||||
self.conv2 = nn.Conv2d(32, 4, 1)
|
||||
self.sigmoid2 = nn.Sigmoid()
|
||||
self.conv3 = nn.Conv2d(32, 1, 1)
|
||||
self.sigmoid3 = nn.Sigmoid()
|
||||
self.scope = scope
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.PI = 3.1415926535898
|
||||
|
||||
def construct(self, x):
|
||||
score = self.sigmoid1(self.conv1(x))
|
||||
loc = self.sigmoid2(self.conv2(x)) * self.scope
|
||||
angle = (self.sigmoid3(self.conv3(x)) - 0.5) * self.PI
|
||||
geo = self.concat((loc, angle))
|
||||
return score, geo
|
||||
|
||||
|
||||
class EAST(nn.Cell):
|
||||
def __init__(self):
|
||||
super(EAST, self).__init__()
|
||||
self.extractor = VGG16FeatureExtraction()
|
||||
self.merge = Merge()
|
||||
self.output = Output()
|
||||
|
||||
def construct(self, x_1):
|
||||
f_0, f_1, f_2, f_3, f_4 = self.extractor(x_1)
|
||||
x_1 = self.merge(f_0, f_1, f_2, f_3, f_4)
|
||||
score, geo = self.output(x_1)
|
||||
|
||||
return score, geo
|
||||
|
||||
|
||||
class DiceCoefficient(nn.Cell):
|
||||
def __init__(self):
|
||||
super(DiceCoefficient, self).__init__()
|
||||
self.sum = P.ReduceSum()
|
||||
self.eps = 1e-5
|
||||
|
||||
def construct(self, true_cls, pred_cls):
|
||||
intersection = self.sum(true_cls * pred_cls, ())
|
||||
union = self.sum(true_cls, ()) + self.sum(pred_cls, ()) + self.eps
|
||||
loss = 1. - (2 * intersection / union)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class MyMin(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MyMin, self).__init__()
|
||||
self.abs = P.Abs()
|
||||
|
||||
def construct(self, a, b):
|
||||
return (a + b - self.abs(a - b)) / 2
|
||||
|
||||
|
||||
class EastLossBlock(nn.Cell):
|
||||
def __init__(self):
|
||||
super(EastLossBlock, self).__init__()
|
||||
self.split = P.Split(1, 5)
|
||||
self.min = MyMin()
|
||||
self.log = P.Log()
|
||||
self.cos = P.Cos()
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.sum = P.ReduceSum()
|
||||
self.eps = 1e-5
|
||||
self.dice = DiceCoefficient()
|
||||
|
||||
def construct(
|
||||
self,
|
||||
y_true_cls,
|
||||
y_pred_cls,
|
||||
y_true_geo,
|
||||
y_pred_geo,
|
||||
training_mask):
|
||||
ans = self.sum(y_true_cls)
|
||||
classification_loss = self.dice(
|
||||
y_true_cls, y_pred_cls * (1 - training_mask))
|
||||
|
||||
# n * 5 * h * w
|
||||
d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = self.split(y_true_geo)
|
||||
d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = self.split(y_pred_geo)
|
||||
area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
|
||||
area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
|
||||
w_union = self.min(d2_gt, d2_pred) + self.min(d4_gt, d4_pred)
|
||||
h_union = self.min(d1_gt, d1_pred) + self.min(d3_gt, d3_pred)
|
||||
|
||||
area_intersect = w_union * h_union
|
||||
area_union = area_gt + area_pred - area_intersect
|
||||
iou_loss_map = -self.log((area_intersect + 1.0) /
|
||||
(area_union + 1.0)) # iou_loss_map
|
||||
angle_loss_map = 1 - self.cos(theta_pred - theta_gt) # angle_loss_map
|
||||
|
||||
angle_loss = self.sum(angle_loss_map * y_true_cls) / ans
|
||||
iou_loss = self.sum(iou_loss_map * y_true_cls) / ans
|
||||
geo_loss = 10 * angle_loss + iou_loss
|
||||
|
||||
return geo_loss + classification_loss
|
||||
|
||||
|
||||
class EastWithLossCell(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(EastWithLossCell, self).__init__()
|
||||
self.east_network = network
|
||||
self.loss = EastLossBlock()
|
||||
|
||||
def construct(self, img, true_cls, true_geo, training_mask):
|
||||
score, geometry = self.east_network(img)
|
||||
loss = self.loss(
|
||||
true_cls,
|
||||
score,
|
||||
true_geo,
|
||||
geometry,
|
||||
training_mask)
|
||||
return loss
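
The IoU term in `EastLossBlock` and the `MyMin` identity can be checked on plain NumPy values. This is a sketch under assumptions: `elementwise_min` and `aabb_iou_loss` are hypothetical names, and the four distances are taken in the order top, right, bottom, left so that `d1 + d3` is a height and `d2 + d4` is a width, matching the pairing used in the loss above. The channel order produced by the dataset code should be checked against that pairing.

```python
import numpy as np


def elementwise_min(a, b):
    """min(a, b) written with abs, as in MyMin: (a + b - |a - b|) / 2."""
    return (a + b - np.abs(a - b)) / 2


def aabb_iou_loss(gt, pred):
    """-log IoU of two axis-aligned boxes given as distances from one shared pixel.

    gt and pred are (d1, d2, d3, d4); the order top, right, bottom, left is an assumption.
    """
    d1_gt, d2_gt, d3_gt, d4_gt = gt
    d1_pred, d2_pred, d3_pred, d4_pred = pred
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    # Both boxes contain the pixel, so the overlap along each side is the smaller distance.
    w_inter = elementwise_min(d2_gt, d2_pred) + elementwise_min(d4_gt, d4_pred)
    h_inter = elementwise_min(d1_gt, d1_pred) + elementwise_min(d3_gt, d3_pred)
    area_intersect = w_inter * h_inter
    area_union = area_gt + area_pred - area_intersect
    # The +1.0 smoothing mirrors the loss above and avoids log(0).
    return -np.log((area_intersect + 1.0) / (area_union + 1.0))
```

A perfect prediction gives a loss of zero, and the loss grows as the predicted box shrinks inside the ground-truth box.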
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parameter init."""
|
||||
import math
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.common.initializer import Initializer as MeInitializer
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
|
||||
================= ====================================================
|
||||
nonlinearity gain
|
||||
================= ====================================================
|
||||
Linear / Identity :math:`1`
|
||||
Conv{1,2,3}D :math:`1`
|
||||
Sigmoid :math:`1`
|
||||
Tanh :math:`\frac{5}{3}`
|
||||
ReLU :math:`\sqrt{2}`
|
||||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
||||
================= ====================================================
|
||||
|
||||
Args:
|
||||
nonlinearity: the non-linear function (`nn.functional` name)
|
||||
param: optional parameter for the non-linear function
|
||||
|
||||
Examples:
|
||||
>>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
|
||||
"""
|
||||
linear_fns = [
|
||||
'linear',
|
||||
'conv1d',
|
||||
'conv2d',
|
||||
'conv3d',
|
||||
'conv_transpose1d',
|
||||
'conv_transpose2d',
|
||||
'conv_transpose3d']
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
return 1
|
||||
if nonlinearity == 'tanh':
|
||||
return 5.0 / 3
|
||||
if nonlinearity == 'relu':
|
||||
return math.sqrt(2.0)
|
||||
if nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError(
|
||||
"negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
|
||||
def _assignment(arr, num):
|
||||
"""Assign the value of 'num' to 'arr'."""
|
||||
if arr.shape == ():
|
||||
arr = arr.reshape((1))
|
||||
arr[:] = num
|
||||
arr = arr.reshape(())
|
||||
else:
|
||||
if isinstance(num, np.ndarray):
|
||||
arr[:] = num[:]
|
||||
else:
|
||||
arr[:] = num
|
||||
return arr
|
||||
|
||||
|
||||
def _calculate_correct_fan(array, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError(
|
||||
"Mode {} not supported, please use one of {}".format(
|
||||
mode, valid_modes))
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
r"""Return an array shaped like `arr`, sampled according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
uniform distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||
|
||||
Also known as He initialization.
|
||||
|
||||
Args:
|
||||
arr: an n-dimensional array; only its shape is read
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
used with ``'leaky_relu'``)
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
nonlinearity: the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
|
||||
Examples:
|
||||
>>> w = np.empty((3, 5))
|
||||
>>> w = kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||
"""
|
||||
fan = _calculate_correct_fan(arr, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
# Calculate uniform bounds from standard deviation
|
||||
bound = math.sqrt(3.0) * std
|
||||
return np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(arr):
|
||||
"""Calculate fan in and fan out."""
|
||||
dimensions = len(arr.shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError(
|
||||
"Fan in and fan out can not be computed for array with fewer than 2 dimensions")
|
||||
|
||||
num_input_fmaps = arr.shape[1]
|
||||
num_output_fmaps = arr.shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
class KaimingUniform(MeInitializer):
|
||||
"""Kaiming uniform initializer."""
|
||||
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(KaimingUniform, self).__init__()
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, array):
|
||||
tmp = kaiming_uniform_(array, self.a, self.mode, self.nonlinearity)
|
||||
_assignment(array, tmp)
|
||||
|
||||
|
||||
def default_recurisive_init(custom_cell):
|
||||
"""Initialize parameter."""
|
||||
for _, cell in custom_cell.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(
|
||||
init.initializer(
|
||||
KaimingUniform(
|
||||
a=math.sqrt(5)),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias.set_data(init.initializer(init.Uniform(bound),
|
||||
cell.bias.shape,
|
||||
cell.bias.dtype))
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
pass
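
The bound used for convolution weights in `default_recurisive_init` can be worked out by hand. The sketch below recomputes it with NumPy only; `kaiming_uniform_bound` is a hypothetical helper that assumes `fan_in` mode and the `leaky_relu` gain, as in the call above.

```python
import math

import numpy as np


def kaiming_uniform_bound(shape, a=0.0):
    """Uniform bound sqrt(3) * gain / sqrt(fan_in) for a weight of the given shape."""
    # fan_in = input channels * receptive field size (product of the kernel dims).
    fan_in = shape[1] * int(np.prod(shape[2:]))
    gain = math.sqrt(2.0 / (1 + a ** 2))  # leaky_relu gain with negative slope a
    return math.sqrt(3.0) * gain / math.sqrt(fan_in)


# A 3x3 convolution from 3 to 64 channels: fan_in = 3 * 3 * 3 = 27.
shape = (64, 3, 3, 3)
bound = kaiming_uniform_bound(shape, a=math.sqrt(5))
weights = np.random.uniform(-bound, bound, shape)
```

With `a = sqrt(5)` the gain is `sqrt(1/3)`, so the bound simplifies to `1 / sqrt(fan_in)`, the same value the bias initialization above uses.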
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
"""
|
||||
Logger.
|
||||
|
||||
Args:
|
||||
logger_name: String. Logger name.
|
||||
rank: Integer. Rank id.
|
||||
"""
|
||||
|
||||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
self.rank = rank
|
||||
self.log_fn = ''
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""Setup logging file."""
|
||||
self.rank = rank
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + \
|
||||
'_rank_{}.log'.format(rank)
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
file = logging.FileHandler(self.log_fn)
|
||||
file.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
file.setFormatter(formatter)
|
||||
self.addHandler(file)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*' * 70 + '\n') * line_width
|
||||
important_msg += ('*' * line_width + '\n') * 2
|
||||
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
|
||||
important_msg += ('*' * line_width + '\n') * 2
|
||||
important_msg += ('*' * 70 + '\n') * line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
"""Get Logger."""
|
||||
logger = LOGGER('east_vgg', rank)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
|
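The logger above attaches a console handler only when `rank % 8 == 0`. The reduced sketch below reproduces that gating with the standard `logging` module alone. `RankLogger` and its `ranks_per_node` parameter are hypothetical names, and reading the constant 8 as "devices per node" is an assumption rather than something the file states.

```python
import logging
import sys


class RankLogger(logging.Logger):
    """Hypothetical, reduced stand-in for the LOGGER class."""

    def __init__(self, name, rank=0, ranks_per_node=8):
        super().__init__(name)
        self.rank = rank
        # Only ranks that are a multiple of ranks_per_node echo to stdout.
        if rank % ranks_per_node == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            console.setFormatter(
                logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
            self.addHandler(console)


master = RankLogger('east_demo', rank=0)
worker = RankLogger('east_demo', rank=3)
master.info('written to stdout by rank 0')
worker.info('not shown: rank 3 has no console handler')
```

A file handler, as added by `setup_logging_file`, is independent of this gating, so every rank can still keep its own log file.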
@ -0,0 +1,242 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate scheduler."""
|
||||
import math
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
"""Linearly interpolate the learning rate from init_lr to base_lr during warmup."""
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
|
||||
def warmup_step_lr(
|
||||
lr,
|
||||
lr_epochs,
|
||||
steps_per_epoch,
|
||||
warmup_epochs,
|
||||
max_epoch,
|
||||
gamma=0.1):
|
||||
"""Warmup step learning rate."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
milestones = lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr_each_step = []
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma ** milestones_steps_counter[i]
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
return warmup_step_lr(
|
||||
lr,
|
||||
milestones,
|
||||
steps_per_epoch,
|
||||
0,
|
||||
max_epoch,
|
||||
gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
lr_epochs = []
|
||||
for i in range(1, max_epoch):
|
||||
if i % epoch_size == 0:
|
||||
lr_epochs.append(i)
|
||||
return multi_step_lr(
|
||||
lr,
|
||||
lr_epochs,
|
||||
steps_per_epoch,
|
||||
max_epoch,
|
||||
gamma=gamma)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(
|
||||
lr,
|
||||
steps_per_epoch,
|
||||
warmup_epochs,
|
||||
max_epoch,
|
||||
t_max,
|
||||
eta_min=0):
|
||||
"""Cosine annealing learning rate."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * \
|
||||
(1. + math.cos(math.pi * last_epoch / t_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_v2(
|
||||
lr,
|
||||
steps_per_epoch,
|
||||
warmup_epochs,
|
||||
max_epoch,
|
||||
t_max,
|
||||
eta_min=0):
|
||||
"""Cosine annealing learning rate V2."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
last_lr = 0
|
||||
last_epoch_v1 = 0
|
||||
|
||||
t_max_v2 = int(max_epoch * 1 / 3)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
if i < total_steps * 2 / 3:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + \
|
||||
math.cos(math.pi * last_epoch / t_max)) / 2
|
||||
last_lr = lr
|
||||
last_epoch_v1 = last_epoch
|
||||
else:
|
||||
base_lr = last_lr
|
||||
last_epoch = last_epoch - last_epoch_v1
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + \
|
||||
math.cos(math.pi * last_epoch / t_max_v2)) / 2
|
||||
|
||||
lr_each_step.append(lr)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_sample(
|
||||
lr,
|
||||
steps_per_epoch,
|
||||
warmup_epochs,
|
||||
max_epoch,
|
||||
t_max,
|
||||
eta_min=0):
|
||||
"""Warmup cosine annealing learning rate."""
|
||||
start_sample_epoch = 60
|
||||
step_sample = 2
|
||||
tobe_sampled_epoch = 60
|
||||
end_sampled_epoch = start_sample_epoch + step_sample * tobe_sampled_epoch
|
||||
max_sampled_epoch = max_epoch + tobe_sampled_epoch
|
||||
t_max = max_sampled_epoch
|
||||
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
|
||||
for i in range(total_sampled_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if last_epoch in range(
|
||||
start_sample_epoch,
|
||||
end_sampled_epoch,
|
||||
step_sample):
|
||||
continue
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * \
|
||||
(1. + math.cos(math.pi * last_epoch / t_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
assert total_steps == len(lr_each_step)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def my_lr(max_epoch, steps_per_epoch, per_step=2, gamma=0.1, lr=0.001):
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * max_epoch
|
||||
n = total_steps // per_step
|
||||
for i in range(max_epoch * steps_per_epoch):
|
||||
if i % n == 0 and i != 0:
|
||||
lr = lr * gamma
|
||||
lr_each_step.append(lr)
|
||||
|
||||
assert total_steps == len(lr_each_step)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def get_lr(args):
|
||||
"""Generate the per-step learning rate array selected by args.lr_scheduler."""
|
||||
if args.lr_scheduler == 'exponential':
|
||||
lr = warmup_step_lr(args.lr,
|
||||
[int(e) for e in args.lr_epochs.split(',')],
|
||||
args.steps_per_epoch,
|
||||
args.warmup_epochs,
|
||||
args.max_epoch,
|
||||
gamma=args.lr_gamma,
|
||||
)
|
||||
elif args.lr_scheduler == 'cosine_annealing':
|
||||
lr = warmup_cosine_annealing_lr(args.lr,
|
||||
args.steps_per_epoch,
|
||||
args.warmup_epochs,
|
||||
args.max_epoch,
|
||||
args.t_max,
|
||||
args.eta_min)
|
||||
elif args.lr_scheduler == 'cosine_annealing_V2':
|
||||
lr = warmup_cosine_annealing_lr_v2(args.lr,
|
||||
args.steps_per_epoch,
|
||||
args.warmup_epochs,
|
||||
args.max_epoch,
|
||||
args.t_max,
|
||||
args.eta_min)
|
||||
elif args.lr_scheduler == 'cosine_annealing_sample':
|
||||
lr = warmup_cosine_annealing_lr_sample(args.lr,
|
||||
args.steps_per_epoch,
|
||||
args.warmup_epochs,
|
||||
args.max_epoch,
|
||||
args.t_max,
|
||||
args.eta_min)
|
||||
elif args.lr_scheduler == 'my_lr':
|
||||
lr = my_lr(
|
||||
args.max_epoch,
|
||||
args.steps_per_epoch,
|
||||
args.per_step,
|
||||
args.lr_gamma,
|
||||
args.lr)
|
||||
else:
|
||||
raise NotImplementedError(args.lr_scheduler)
|
||||
return lr
|
|
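The warmup-plus-milestone schedule in this file can be illustrated without NumPy. The sketch below is a pure-Python approximation of `warmup_step_lr` and returns a plain list instead of a `float32` array. Because the training script declares `--lr_epochs` as a comma-separated string, milestones have to be integers before they are multiplied by `steps_per_epoch`; `parse_lr_epochs` is a hypothetical helper for that step, and `warmup_step_schedule` is likewise an illustrative name.

```python
from collections import Counter


def parse_lr_epochs(text):
    """Turn a comma-separated string such as '7,7' into a list of ints."""
    return [int(part) for part in text.split(',') if part.strip()]


def warmup_step_schedule(base_lr, milestones, steps_per_epoch,
                         warmup_epochs, max_epoch, gamma=0.1):
    """Linear warmup from 0 to base_lr, then decay by gamma at milestones."""
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    # A milestone listed twice decays the rate twice at that step.
    decay_at = Counter(m * steps_per_epoch for m in milestones)

    schedule = []
    lr = base_lr
    for step in range(total_steps):
        if step < warmup_steps:
            lr = base_lr * (step + 1) / warmup_steps
        else:
            lr = lr * gamma ** decay_at[step]
        schedule.append(lr)
    return schedule


demo = warmup_step_schedule(1.0, parse_lr_epochs('2,2'),
                            steps_per_epoch=2, warmup_epochs=1,
                            max_epoch=4, gamma=0.5)
print(demo)
```

In this toy run the rate ramps up over the first epoch, stays at `base_lr` until epoch 2, and is then multiplied by `gamma` twice because the milestone appears twice in the string.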
@ -0,0 +1,121 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
self.tb_writer = tb_writer
|
||||
self.cur_step = 1
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||
self.cur_step += 1
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
def default_wd_filter(x):
|
||||
"""default weight decay filter."""
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# biases are excluded from weight decay
|
||||
return False
|
||||
if parameter_name.endswith('.gamma'):
|
||||
# BN gamma is excluded from weight decay; note that x may currently
|
||||
# not include any BN parameters
|
||||
return False
|
||||
if parameter_name.endswith('.beta'):
|
||||
# BN beta is excluded from weight decay; note that x may currently
|
||||
# not include any BN parameters
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_param_groups(network):
|
||||
"""Param groups for optimizer."""
|
||||
decay_params = []
|
||||
no_decay_params = []
|
||||
for x in network.trainable_params():
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# biases are excluded from weight decay
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.gamma'):
|
||||
# BN gamma is excluded from weight decay; note that x may currently
|
||||
# not include any BN parameters
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.beta'):
|
||||
# BN beta is excluded from weight decay; note that x may currently
|
||||
# not include any BN parameters
|
||||
no_decay_params.append(x)
|
||||
else:
|
||||
decay_params.append(x)
|
||||
|
||||
return [{'params': no_decay_params, 'weight_decay': 0.0},
|
||||
{'params': decay_params}]
|
||||
|
||||
|
||||
class ShapeRecord:
|
||||
"""Log image shape."""
|
||||
|
||||
def __init__(self):
|
||||
self.shape_record = {
|
||||
416: 0,
|
||||
448: 0,
|
||||
480: 0,
|
||||
512: 0,
|
||||
544: 0,
|
||||
576: 0,
|
||||
608: 0,
|
||||
640: 0,
|
||||
672: 0,
|
||||
704: 0,
|
||||
736: 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
def set(self, shape):
|
||||
if len(shape) > 1:
|
||||
shape = shape[0]
|
||||
shape = int(shape)
|
||||
self.shape_record[shape] += 1
|
||||
self.shape_record['total'] += 1
|
||||
|
||||
def show(self, logger):
|
||||
for key in self.shape_record:
|
||||
rate = self.shape_record[key] / float(self.shape_record['total'])
|
||||
logger.info('shape {}: {:.2f}%'.format(key, rate * 100))
|
|
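`get_param_groups` separates parameters by the suffix of their name so that biases and BatchNorm `gamma`/`beta` receive no weight decay. The sketch below shows the same partitioning on plain Python objects. `Param` is a hypothetical stand-in for a framework parameter (only its `.name` attribute is used), and `split_param_groups` is an illustrative name, not part of the repository.

```python
from collections import namedtuple

# Hypothetical stand-in for a framework parameter; only `.name` is used.
Param = namedtuple('Param', ['name'])

NO_DECAY_SUFFIXES = ('.bias', '.gamma', '.beta')


def split_param_groups(params):
    """Partition parameters by name suffix into no-decay and decay groups."""
    no_decay = [p for p in params if p.name.endswith(NO_DECAY_SUFFIXES)]
    decay = [p for p in params if not p.name.endswith(NO_DECAY_SUFFIXES)]
    return [{'params': no_decay, 'weight_decay': 0.0},
            {'params': decay}]


params = [Param('conv1.weight'), Param('conv1.bias'),
          Param('bn1.gamma'), Param('bn1.beta')]
groups = split_param_groups(params)
print([p.name for p in groups[0]['params']])
print([p.name for p in groups[1]['params']])
```

The second group carries no `weight_decay` key, so it falls back to whatever default the optimizer is constructed with.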
@ -0,0 +1,318 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.optim.adam import Adam
|
||||
from mindspore import Tensor, Model
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
import mindspore as ms
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.profiler.profiling import Profiler
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
|
||||
from src.util import AverageMeter, get_param_groups
|
||||
from src.east import EAST, EastWithLossCell
|
||||
from src.logger import get_logger
|
||||
from src.initializer import default_recurisive_init
|
||||
from src.dataset import create_east_dataset
|
||||
from src.lr_scheduler import get_lr
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore icdar training')
|
||||
|
||||
# device related
|
||||
parser.add_argument(
|
||||
'--device_target',
|
||||
type=str,
|
||||
default='Ascend',
|
||||
help='Device on which the code will run. (Default: Ascend)')
|
||||
parser.add_argument(
|
||||
'--device_id',
|
||||
type=int,
|
||||
default=0,
|
||||
help='ID of the device to run on. (Default: 0)')
|
||||
|
||||
# dataset related
|
||||
parser.add_argument(
|
||||
'--data_dir',
|
||||
default='/data/icdar2015/Training/',
|
||||
type=str,
|
||||
help='Train dataset directory.')
|
||||
parser.add_argument(
|
||||
'--per_batch_size',
|
||||
default=24,
|
||||
type=int,
|
||||
help='Batch size for Training. Default: 24.')
|
||||
parser.add_argument(
|
||||
'--outputs_dir',
|
||||
default='outputs/',
|
||||
type=str,
|
||||
help='output dir. Default: outputs/')
|
||||
|
||||
# network related
|
||||
parser.add_argument(
|
||||
'--pretrained_backbone',
|
||||
default='/data/vgg/0-150_5004.ckpt',
|
||||
type=str,
|
||||
help='The ckpt file of the pretrained backbone. Default: /data/vgg/0-150_5004.ckpt')
|
||||
parser.add_argument(
|
||||
'--resume_east',
|
||||
default='',
|
||||
type=str,
|
||||
help='The ckpt file of EAST, which is used for fine-tuning. Default: ""')
|
||||
|
||||
# optimizer and lr related
|
||||
parser.add_argument(
|
||||
'--lr_scheduler',
|
||||
default='my_lr',
|
||||
type=str,
|
||||
help='Learning rate scheduler, options: exponential, cosine_annealing, cosine_annealing_V2, cosine_annealing_sample, my_lr. Default: my_lr')
|
||||
parser.add_argument('--lr', default=0.001, type=float,
|
||||
help='Learning rate. Default: 0.001')
|
||||
parser.add_argument('--per_step', default=2, type=float,
|
||||
help='Number of equal stages in the my_lr schedule; lr is multiplied by lr_gamma at each stage boundary. Default: 2')
|
||||
parser.add_argument(
|
||||
'--lr_gamma',
|
||||
type=float,
|
||||
default=0.1,
|
||||
help='Factor by which lr is multiplied at each decay step. Default: 0.1')
|
||||
parser.add_argument(
|
||||
'--eta_min',
|
||||
type=float,
|
||||
default=0.,
|
||||
help='Eta_min in cosine_annealing scheduler. Default: 0.')
|
||||
parser.add_argument(
|
||||
'--t_max',
|
||||
type=int,
|
||||
default=100,
|
||||
help='T-max in cosine_annealing scheduler. Default: 100')
|
||||
parser.add_argument('--max_epoch', type=int, default=600,
|
||||
help='Max epoch num to train the model. Default: 600')
|
||||
parser.add_argument(
|
||||
'--warmup_epochs',
|
||||
default=6,
|
||||
type=float,
|
||||
help='Warmup epochs. Default: 6')
|
||||
parser.add_argument(
|
||||
'--weight_decay',
|
||||
type=float,
|
||||
default=0.0005,
|
||||
help='Weight decay factor. Default: 0.0005')
|
||||
|
||||
# loss related
|
||||
parser.add_argument('--loss_scale', type=int, default=1,
|
||||
help='Static loss scale. Default: 1')
|
||||
parser.add_argument(
|
||||
'--lr_epochs',
|
||||
type=str,
|
||||
default='7,7',
|
||||
help='Epochs at which lr changes, separated by ",". Default: 7,7')
|
||||
|
||||
# logging related
|
||||
parser.add_argument('--log_interval', type=int, default=10,
|
||||
help='Logging interval steps. Default: 10')
|
||||
parser.add_argument(
|
||||
'--ckpt_path',
|
||||
type=str,
|
||||
default='outputs/',
|
||||
help='Checkpoint save location. Default: outputs/')
|
||||
parser.add_argument(
|
||||
'--ckpt_interval',
|
||||
type=int,
|
||||
default=1000,
|
||||
help='Save checkpoint interval. Default: 1000')
|
||||
|
||||
parser.add_argument(
|
||||
'--is_save_on_master',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')
|
||||
|
||||
# distributed related
|
||||
parser.add_argument(
|
||||
'--is_distributed',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Distributed training or not, 1 for yes, 0 for no. Default: 0')
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Local rank in distributed training. Default: 0')
|
||||
parser.add_argument('--group_size', type=int, default=1,
|
||||
help='World size of device. Default: 1')
|
||||
|
||||
# profiler init
|
||||
parser.add_argument(
|
||||
'--need_profiler',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Whether use profiler. 0 for no, 1 for yes. Default: 0')
|
||||
# modelArts
|
||||
parser.add_argument(
|
||||
'--is_modelArts',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Training on ModelArts or not, 1 for yes, 0 for no. Default: 0')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
args.rank = args.device_id
|
||||
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
if args.device_target == "Ascend":
|
||||
init()
|
||||
else:
|
||||
init("nccl")
|
||||
args.rank = int(os.getenv('DEVICE_ID'))
|
||||
args.group_size = int(os.getenv('RANK_SIZE'))
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target,
|
||||
save_graphs=False,
|
||||
device_id=args.rank)
|
||||
|
||||
# choose whether only the master rank or every rank saves checkpoints;
|
||||
# kept compatible with model parallel
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if args.rank == 0:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
else:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
|
||||
if args.is_modelArts:
|
||||
import moxing as mox
|
||||
|
||||
local_data_url = os.path.join('/cache/data', str(args.rank))
|
||||
local_ckpt_url = os.path.join('/cache/ckpt', str(args.rank))
|
||||
local_ckpt_url = os.path.join(local_ckpt_url, 'backbone.ckpt')
|
||||
|
||||
mox.file.rename(args.pretrained_backbone, local_ckpt_url)
|
||||
args.pretrained_backbone = local_ckpt_url
|
||||
|
||||
mox.file.copy_parallel(args.data_dir, local_data_url)
|
||||
args.data_dir = local_data_url
|
||||
|
||||
args.outputs_dir = os.path.join('/cache', args.outputs_dir)
|
||||
|
||||
args.data_root = os.path.abspath(os.path.join(args.data_dir, 'image'))
|
||||
args.txt_root = os.path.abspath(os.path.join(args.data_dir, 'groundTruth'))
|
||||
|
||||
outputs_dir = os.path.join(args.outputs_dir, str(args.rank))
|
||||
args.outputs_dir = os.path.join(
|
||||
args.outputs_dir,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||
args.logger.save_args(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.need_profiler:
|
||||
profiler = Profiler(
|
||||
output_path=args.outputs_dir,
|
||||
is_detail=True,
|
||||
is_show_op_path=True)
|
||||
|
||||
loss_meter = AverageMeter('loss')
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
parallel_mode = ParallelMode.STAND_ALONE
|
||||
degree = 1
|
||||
if args.is_distributed:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
degree = int(os.getenv('RANK_SIZE'))
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode=parallel_mode,
|
||||
gradients_mean=True,
|
||||
device_num=degree)
|
||||
|
||||
network = EAST()
|
||||
# default is kaiming-normal
|
||||
default_recurisive_init(network)
|
||||
|
||||
# load pretrained_backbone
|
||||
if args.pretrained_backbone:
|
||||
parm_dict = load_checkpoint(args.pretrained_backbone)
|
||||
load_param_into_net(network, parm_dict)
|
||||
args.logger.info('finish load pretrained_backbone')
|
||||
|
||||
network = EastWithLossCell(network)
|
||||
if args.resume_east:
|
||||
parm_dict = load_checkpoint(args.resume_east)
|
||||
load_param_into_net(network, parm_dict)
|
||||
args.logger.info('finish get resume east')
|
||||
|
||||
args.logger.info('finish get network')
|
||||
|
||||
ds, data_size = create_east_dataset(img_root=args.data_root, txt_root=args.txt_root, batch_size=args.per_batch_size,
|
||||
device_num=args.group_size, rank=args.rank, is_training=True)
|
||||
args.logger.info('Finish loading dataset')
|
||||
|
||||
args.steps_per_epoch = int(
|
||||
data_size /
|
||||
args.per_batch_size /
|
||||
args.group_size)
|
||||
|
||||
if not args.ckpt_interval:
|
||||
args.ckpt_interval = args.steps_per_epoch
|
||||
|
||||
# get learning rate
|
||||
lr = get_lr(args)
|
||||
opt = Adam(
|
||||
params=get_param_groups(network),
|
||||
learning_rate=Tensor(
|
||||
lr,
|
||||
ms.float32))
|
||||
loss_scale = FixedLossScaleManager(1.0, drop_overflow_update=True)
|
||||
model = Model(network, optimizer=opt, loss_scale_manager=loss_scale)
|
||||
network.set_train()
|
||||
|
||||
# save the network model and parameters for subsequent fine-tuning
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=100,
|
||||
keep_checkpoint_max=1)
|
||||
# group layers into an object with training and evaluation features
|
||||
save_ckpt_path = os.path.join(
|
||||
args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||
ckpoint_cb = ModelCheckpoint(
|
||||
prefix="checkpoint_east",
|
||||
directory=save_ckpt_path,
|
||||
config=config_ck)
|
||||
callback = []
|
||||
if args.rank == 0:
|
||||
callback = [
|
||||
TimeMonitor(
|
||||
data_size=data_size),
|
||||
LossMonitor(),
|
||||
ckpoint_cb]
|
||||
|
||||
save_ckpt_path = os.path.join(
|
||||
args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||
model.train(
|
||||
args.max_epoch,
|
||||
ds,
|
||||
callbacks=callback,
|
||||
dataset_sink_mode=False)
|
||||
args.logger.info('==========end training===============')
|
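Two small pieces of bookkeeping in the training script are easy to check in isolation: the number of optimizer steps per epoch, and the decision about which rank writes checkpoints. The sketch below restates both as plain functions. `steps_per_epoch` and `should_save_ckpt` are hypothetical names that mirror the expressions used for `args.steps_per_epoch` and `args.rank_save_ckpt_flag`; they are not functions from the repository.

```python
def steps_per_epoch(data_size, per_batch_size, group_size):
    """Steps each device runs per epoch; a partial last batch is dropped."""
    return int(data_size / per_batch_size / group_size)


def should_save_ckpt(is_save_on_master, rank):
    """With is_save_on_master set, only rank 0 saves; otherwise all ranks."""
    if is_save_on_master:
        return rank == 0
    return True


# 1000 samples, batch size 24: one device runs 41 steps, eight run 5 each.
print(steps_per_epoch(1000, 24, 1), steps_per_epoch(1000, 24, 8))
print(should_save_ckpt(1, 0), should_save_ckpt(1, 3), should_save_ckpt(0, 3))
```

Note that in the script as shown the checkpoint callback is attached only when `args.rank == 0`, so the flag computed from `--is_save_on_master` does not appear to be consulted when the callback list is built.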