forked from mindspore-Ecosystem/mindspore
!22826 modify etsnet to psenet
Merge pull request !22826 from anzhengqi/update-psenet
This commit is contained in:
commit
a7e34a0526
|
@ -176,7 +176,7 @@ bash scripts/run_eval_ascend.sh
|
|||
## [Script and Sample Code](#contents)
|
||||
|
||||
```path
|
||||
└── PSENet
|
||||
└── psenet
|
||||
├── export.py // export mindir file
|
||||
├── postprocess.py // 310 Inference post-processing script
|
||||
├── __init__.py
|
||||
|
@ -196,10 +196,10 @@ bash scripts/run_eval_ascend.sh
|
|||
├──device_adapter.py // Device Config
|
||||
├──local_adapter.py // local device config
|
||||
├── dataset.py // creating dataset
|
||||
├── ETSNET
|
||||
├── PSENET
|
||||
├── base.py // convolution and BN operator
|
||||
├── dice_loss.py // calculate PSENet loss value
|
||||
├── etsnet.py // Subnet in PSENet
|
||||
├── psenet.py // Subnet in PSENet
|
||||
├── fpn.py // Subnet in PSENet
|
||||
├── __init__.py
|
||||
├── pse // Subnet in PSENet
|
||||
|
@ -434,7 +434,7 @@ dataset = dataset.create_dataset(cfg.data_path, 1, False)
|
|||
|
||||
# Define model
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = PSENet(config)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(args.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
|
|
@ -176,7 +176,7 @@ bash scripts/run_eval_ascend.sh
|
|||
## 脚本和样例代码
|
||||
|
||||
```path
|
||||
└── PSENet
|
||||
└── psenet
|
||||
├── export.py // mindir转换脚本
|
||||
├── mindspore_hub_conf.py // 网络模型
|
||||
├─postprogress.py # 310推理后处理脚本
|
||||
|
@ -193,10 +193,10 @@ bash scripts/run_eval_ascend.sh
|
|||
├──local_adapter.py # 设备相关信息
|
||||
├──moxing_adapter.py # 装饰器(主要用于ModelArts数据拷贝)
|
||||
├── dataset.py // 创建数据集
|
||||
├── ETSNET
|
||||
├── PSENET
|
||||
├── base.py // 卷积和BN算子
|
||||
├── dice_loss.py // 计算PSENet损耗值
|
||||
├── etsnet.py // PSENet中的子网
|
||||
├── psenet.py // PSENet中的子网
|
||||
├── fpn.py // PSENet中的子网
|
||||
├── __init__.py
|
||||
├── pse // PSENet中的子网
|
||||
|
@ -371,7 +371,7 @@ dataset = dataset.create_dataset(cfg.data_path, 1, False)
|
|||
|
||||
# 定义模型
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = PSENet(config)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(args.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|||
import mindspore as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
from src.model_utils.config import config
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.PSENET.psenet import PSENet
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
|
@ -37,7 +37,7 @@ if config.device_target == "Ascend":
|
|||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
net = ETSNet(config)
|
||||
net = PSENet(config)
|
||||
param_dict = load_checkpoint(config.ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from functools import reduce
|
|||
import numpy as np
|
||||
import cv2
|
||||
from src.model_utils.config import config
|
||||
from src.ETSNET.pse import pse
|
||||
from src.PSENET.pse import pse
|
||||
|
||||
|
||||
def sort_to_clockwise(points):
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
CXXFLAGS = -std=c++11 -O3
|
||||
mindspore_home = ../../../../../../../
|
||||
CXXFLAGS = -I include -I ${mindspore_home} -std=c++11 -O3
|
||||
CXX_SOURCES = adaptor.cpp
|
||||
opencv_home = ${OPENCV_HOME}
|
||||
OPENCV = -I$(opencv_home)/include -L$(opencv_home)/lib64 -lopencv_superres -lopencv_ml -lopencv_objdetect \
|
|
@ -25,7 +25,8 @@
|
|||
#include <opencv2/core/core.hpp>
|
||||
#include <opencv2/highgui/highgui.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
#include "./adaptor.h"
|
||||
|
||||
#include "model_zoo/official/cv/psenet/src/PSENET/pse/adaptor.h"
|
||||
|
||||
using std::vector;
|
||||
using std::queue;
|
|
@ -24,9 +24,9 @@ from .resnet50 import ResNet, ResidualBlock
|
|||
from .fpn import FPN
|
||||
|
||||
|
||||
class ETSNet(nn.Cell):
|
||||
class PSENet(nn.Cell):
|
||||
def __init__(self, config):
|
||||
super(ETSNet, self).__init__()
|
||||
super(PESNet, self).__init__()
|
||||
self.kernel_num = config.KERNEL_NUM
|
||||
self.inference = config.INFERENCE
|
||||
if config.INFERENCE:
|
||||
|
@ -65,7 +65,7 @@ class ETSNet(nn.Cell):
|
|||
self.greater = P.Greater()
|
||||
self.logic_and = P.LogicalAnd()
|
||||
|
||||
print('ETSNet initialized!')
|
||||
print('PSENet initialized!')
|
||||
|
||||
def construct(self, x):
|
||||
c2, c3, c4, c5 = self.feature_extractor(x)
|
|
@ -25,7 +25,7 @@ from mindspore import Tensor, context
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import test_dataset_creator
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.PSENET.psenet import PSENet
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
@ -108,7 +108,7 @@ def test():
|
|||
ds = test_dataset_creator()
|
||||
|
||||
config.INFERENCE = True
|
||||
net = ETSNet(config)
|
||||
net = PSENet(config)
|
||||
print(config.ckpt)
|
||||
param_dict = load_checkpoint(config.ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
|
|
@ -25,8 +25,8 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import train_dataset_creator
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.ETSNET.dice_loss import DiceLoss
|
||||
from src.PSENET.psenet import PSENet
|
||||
from src.PSENET.dice_loss import DiceLoss
|
||||
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.model_utils.config import config
|
||||
|
@ -95,7 +95,7 @@ def train():
|
|||
print('Create dataset done!')
|
||||
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = PSENet(config)
|
||||
net = net.set_train()
|
||||
|
||||
if config.pre_trained:
|
||||
|
|
Loading…
Reference in New Issue