!22826 modify etsnet to psenet

Merge pull request !22826 from anzhengqi/update-psenet
This commit is contained in:
i-robot 2021-09-04 02:53:13 +00:00 committed by Gitee
commit a7e34a0526
16 changed files with 23 additions and 21 deletions

View File

@ -176,7 +176,7 @@ bash scripts/run_eval_ascend.sh
## [Script and Sample Code](#contents)
```path
└── PSENet
└── psenet
├── export.py // export mindir file
├── postprocess.py // 310 Inference post-processing script
├── __init__.py
@ -196,10 +196,10 @@ bash scripts/run_eval_ascend.sh
├──device_adapter.py // Device Config
├──local_adapter.py // local device config
├── dataset.py // creating dataset
├── ETSNET
├── PSENET
├── base.py // convolution and BN operator
├── dice_loss.py // calculate PSENet loss value
├── etsnet.py // Subnet in PSENet
├── psenet.py // Subnet in PSENet
├── fpn.py // Subnet in PSENet
├── __init__.py
├── pse // Subnet in PSENet
@ -434,7 +434,7 @@ dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
config.INFERENCE = False
net = ETSNet(config)
net = PSENet(config)
net = net.set_train()
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)

View File

@ -176,7 +176,7 @@ bash scripts/run_eval_ascend.sh
## 脚本和样例代码
```path
└── PSENet
└── psenet
├── export.py // mindir转换脚本
├── mindspore_hub_conf.py // 网络模型
├─postprogress.py # 310推理后处理脚本
@ -193,10 +193,10 @@ bash scripts/run_eval_ascend.sh
├──local_adapter.py # 设备相关信息
├──moxing_adapter.py # 装饰器(主要用于ModelArts数据拷贝)
├── dataset.py // 创建数据集
├── ETSNET
├── PSENET
├── base.py // 卷积和BN算子
├── dice_loss.py // 计算PSENet损耗值
├── etsnet.py // PSENet中的子网
├── psenet.py // PSENet中的子网
├── fpn.py // PSENet中的子网
├── __init__.py
├── pse // PSENet中的子网
@ -371,7 +371,7 @@ dataset = dataset.create_dataset(cfg.data_path, 1, False)
# 定义模型
config.INFERENCE = False
net = ETSNet(config)
net = PSENet(config)
net = net.set_train()
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)

View File

@ -20,7 +20,7 @@ import numpy as np
import mindspore as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.model_utils.config import config
from src.ETSNET.etsnet import ETSNet
from src.PSENET.psenet import PSENet
from src.model_utils.moxing_adapter import moxing_wrapper
@ -37,7 +37,7 @@ if config.device_target == "Ascend":
@moxing_wrapper(pre_process=modelarts_pre_process)
def model_export():
net = ETSNet(config)
net = PSENet(config)
param_dict = load_checkpoint(config.ckpt)
load_param_into_net(net, param_dict)

View File

@ -21,7 +21,7 @@ from functools import reduce
import numpy as np
import cv2
from src.model_utils.config import config
from src.ETSNET.pse import pse
from src.PSENET.pse import pse
def sort_to_clockwise(points):

View File

@ -13,7 +13,8 @@
# limitations under the License.
# ============================================================================
CXXFLAGS = -std=c++11 -O3
mindspore_home = ../../../../../../../
CXXFLAGS = -I include -I ${mindspore_home} -std=c++11 -O3
CXX_SOURCES = adaptor.cpp
opencv_home = ${OPENCV_HOME}
OPENCV = -I$(opencv_home)/include -L$(opencv_home)/lib64 -lopencv_superres -lopencv_ml -lopencv_objdetect \

View File

@ -25,7 +25,8 @@
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "./adaptor.h"
#include "model_zoo/official/cv/psenet/src/PSENET/pse/adaptor.h"
using std::vector;
using std::queue;

View File

@ -24,9 +24,9 @@ from .resnet50 import ResNet, ResidualBlock
from .fpn import FPN
class ETSNet(nn.Cell):
class PSENet(nn.Cell):
def __init__(self, config):
super(ETSNet, self).__init__()
super(PESNet, self).__init__()
self.kernel_num = config.KERNEL_NUM
self.inference = config.INFERENCE
if config.INFERENCE:
@ -65,7 +65,7 @@ class ETSNet(nn.Cell):
self.greater = P.Greater()
self.logic_and = P.LogicalAnd()
print('ETSNet initialized!')
print('PSENet initialized!')
def construct(self, x):
c2, c3, c4, c5 = self.feature_extractor(x)

View File

@ -25,7 +25,7 @@ from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import test_dataset_creator
from src.ETSNET.etsnet import ETSNet
from src.PSENET.psenet import PSENet
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
@ -108,7 +108,7 @@ def test():
ds = test_dataset_creator()
config.INFERENCE = True
net = ETSNet(config)
net = PSENet(config)
print(config.ckpt)
param_dict = load_checkpoint(config.ckpt)
load_param_into_net(net, param_dict)

View File

@ -25,8 +25,8 @@ from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.dataset import train_dataset_creator
from src.ETSNET.etsnet import ETSNet
from src.ETSNET.dice_loss import DiceLoss
from src.PSENET.psenet import PSENet
from src.PSENET.dice_loss import DiceLoss
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
from src.lr_schedule import dynamic_lr
from src.model_utils.config import config
@ -95,7 +95,7 @@ def train():
print('Create dataset done!')
config.INFERENCE = False
net = ETSNet(config)
net = PSENet(config)
net = net.set_train()
if config.pre_trained: