forked from mindspore-Ecosystem/mindspore
[icnet_7_1_4_910inference]
This commit is contained in:
parent
d32e606ed8
commit
366bab9862
|
@ -0,0 +1,197 @@
|
|||
# Contents

- [ICNet Description](#icnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environmental Requirements](#environmental-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Distributed Training](#distributed-training)
        - [Training Result](#training-result)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
        - [Evaluation Result](#evaluation-result)
- [Model Description](#model-description)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [ICNet Description](#Contents)

ICNet (Image Cascade Network) is a fully convolutional network that incorporates multi-resolution branches under proper label guidance to address the challenge of real-time semantic segmentation.

[Paper](https://arxiv.org/abs/1704.08545), ECCV 2018

# [Model Architecture](#Contents)

ICNet takes cascade image inputs (i.e., low-, medium- and high-resolution images), adopts a cascade feature fusion unit, and is trained with cascade label guidance. The input image at full resolution (e.g., 1024×2048 in Cityscapes) feeds the high-resolution branch and is downsampled by factors of 2 and 4 to form the cascade inputs to the medium- and low-resolution branches.
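
As a small illustration of the cascade inputs, the sketch below computes the three input resolutions for a full-resolution Cityscapes frame. The helper name `cascade_input_sizes` and the branch labels are chosen for illustration only and are not part of this repository.

```python
def cascade_input_sizes(height, width, factors=(1, 2, 4)):
    """Return (height, width) for each cascade level, from full to lowest resolution."""
    return [(height // f, width // f) for f in factors]


sizes = cascade_input_sizes(1024, 2048)
for name, size in zip(("high", "medium", "low"), sizes):
    print(name, size)
```

The full-resolution input keeps its 1024×2048 shape, while the factor-2 and factor-4 copies shrink to 512×1024 and 256×512.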

# [Dataset](#Contents)

Dataset used: [Cityscapes Dataset Website](https://www.cityscapes-dataset.com/)

It contains 5,000 finely annotated images split into training, validation and testing sets with 2,975, 500, and 1,525 images respectively.
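
The scripts in this commit pair each image under `leftImg8bit/` with a label file under `gtFine/` by replacing `leftImg8bit` with `gtFine_labelIds` in the file name. The following is a minimal sketch of that pairing rule; the function name `mask_path_for` and the example file name are illustrative assumptions, and `posixpath` is used only to keep the result independent of the host OS.

```python
import posixpath


def mask_path_for(image_path, root):
    """Derive the gtFine label path for a leftImg8bit image path (layout assumed)."""
    city_dir = posixpath.dirname(image_path)
    city = posixpath.basename(city_dir)
    split = posixpath.basename(posixpath.dirname(city_dir))
    mask_name = posixpath.basename(image_path).replace("leftImg8bit", "gtFine_labelIds")
    return posixpath.join(root, "gtFine", split, city, mask_name)


example = mask_path_for(
    "/data/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png",
    "/data/cityscapes",
)
print(example)
```

An image is only usable when both files exist, which is why the dataset helpers skip any pair with a missing mask.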

# [Environmental Requirements](#Contents)

- Hardware (Ascend)
    - Prepare an Ascend processor to build the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For details, please refer to the following resources:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

# [Script Description](#Contents)

## Script and Sample Code

```text
.
└─ICNet
    ├─configs
    │   └─icnet.yaml                    # config file
    ├─models
    │   ├─base_models
    │   │   └─resnt50_v1.py             # ResNet-50 backbone
    │   ├─__init__.py
    │   ├─icnet.py                      # validation network
    │   └─icnet_dc.py                   # training network
    ├─scripts
    │   ├─run_distribute_train8p.sh     # multi-card distributed training on Ascend
    │   └─run_eval.sh                   # validation script
    ├─utils
    │   ├─__init__.py
    │   ├─logger.py                     # logger
    │   ├─loss.py                       # loss
    │   ├─losses.py                     # SoftmaxCrossEntropyLoss
    │   ├─lr_scheduler.py               # learning rate schedule
    │   └─metric.py                     # metric
    ├─eval.py                           # validation
    ├─train.py                          # training
    └─visualize.py                      # inference visualization
```

## Script Parameters

Set script parameters in `configs/icnet.yaml`.

### Model

```yaml
name: "icnet"
backbone: "resnet50"
base_size: 1024    # during augmentation, the shorter side is resized to a value in [base_size*0.5, base_size*2.0]
crop_size: 960     # at the end of augmentation, crop to this size for training
```

### Optimizer

```yaml
init_lr: 0.02
momentum: 0.9
weight_decay: 0.0001
```

### Training

```yaml
train_batch_size_percard: 4
valid_batch_size: 1
cityscapes_root: "/data/cityscapes/"
epochs: 160
val_epoch: 1                 # run validation every val_epoch epochs
ckpt_dir: "./ckpt/"          # checkpoints and the training log are saved here
mindrecord_dir: '/root/bigpingping/mindrecord'
save_checkpoint_epochs: 5
keep_checkpoint_max: 10
```

### Valid

```yaml
ckpt_path: ""    # set the pretrained model path correctly
```
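
The training settings, the 8-device launch and the 2,975 training images together determine how many steps make up one epoch. The sketch below works through that arithmetic. It assumes the dataset is sharded evenly across devices and that the last partial batch is kept (`drop_remainder=False` in `create_icnet_dataset`); the function name is illustrative only.

```python
import math


def steps_per_epoch(num_images, batch_per_card, num_devices):
    """Steps per epoch when each device reads its own shard and keeps the last partial batch."""
    images_per_device = math.ceil(num_images / num_devices)
    return math.ceil(images_per_device / batch_per_card)


global_batch = 4 * 8                  # train_batch_size_percard * number of devices
steps = steps_per_epoch(2975, 4, 8)   # Cityscapes training images, batch per card, devices
total_steps = steps * 160             # epochs
print(global_batch, steps, total_steps)
```

Under these assumptions the result is 93 steps per epoch, which agrees with the `step: 93` entries in the training log, and a global batch size of 32, which agrees with the performance table.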

## Training Process

### Distributed Training

- Run distributed training in the Ascend processor environment:

```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PROJECT_PATH]
```

- Notes:

The hccl.json file specified by [RANK_TABLE_FILE] is used when running distributed tasks. You can use [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate this file.

### Training Result

The training results are saved in the example path, in a folder whose name starts with "ICNet-". You can find the checkpoint files there, and results similar to the following in LOG(0-7)/log.txt.

```text
# distributed training result(8p)
epoch: 1 step: 93, loss is 0.5659234
epoch time: 672111.671 ms, per step time: 7227.007 ms
epoch: 2 step: 93, loss is 1.0220546
epoch time: 66850.354 ms, per step time: 718.821 ms
epoch: 3 step: 93, loss is 0.49694514
epoch time: 70490.510 ms, per step time: 757.962 ms
epoch: 4 step: 93, loss is 0.74727297
epoch time: 73657.396 ms, per step time: 792.015 ms
epoch: 5 step: 93, loss is 0.45953503
epoch time: 97117.785 ms, per step time: 1044.277 ms
```
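
The first epoch in this log is roughly ten times slower than the following ones. That is commonly caused by one-time graph compilation and data pipeline warm-up, although the log itself does not say so. A minimal sketch for extracting a steady-state step time from such a log, assuming the line format shown above and using a hypothetical helper name:

```python
import re

LOG = """\
epoch: 1 step: 93, loss is 0.5659234
epoch time: 672111.671 ms, per step time: 7227.007 ms
epoch: 2 step: 93, loss is 1.0220546
epoch time: 66850.354 ms, per step time: 718.821 ms
epoch: 3 step: 93, loss is 0.49694514
epoch time: 70490.510 ms, per step time: 757.962 ms
"""

STEP_TIME = re.compile(r"per step time:\s*([0-9.]+)\s*ms")


def mean_step_time_ms(log_text, skip_first=1):
    """Average the 'per step time' values, skipping the first (warm-up) epochs."""
    times = [float(m.group(1)) for m in STEP_TIME.finditer(log_text)]
    steady = times[skip_first:]
    return sum(steady) / len(steady)


print(f"{mean_step_time_ms(LOG):.3f} ms per step after warm-up")
```

Skipping the first epoch keeps the one-off startup cost from dominating the average.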

## Evaluation Process

### Evaluation

Check the checkpoint path used for evaluation before running the following command.

```shell
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]
```

### Evaluation Result

The results in eval/log are as follows:

```text
Found 500 images in the folder /data/cityscapes/leftImg8bit/val
pretrained....
2021-06-01 19:03:54,570 semantic_segmentation INFO: Start validation, Total sample: 500
avgmiou 0.69962835
avg_pixacc 0.94285786
avgtime 0.19648232793807982
```
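
The `SegmentationMetric` class that produces these numbers is not part of the files shown here, so the following is only a sketch of how pixel accuracy and mean IoU are commonly computed. It is not the repository's implementation. The function name, the flat-list inputs and the choice to average only over classes that actually occur are assumptions.

```python
def pixel_accuracy_and_miou(pred, target, num_classes, ignore_label=-1):
    """Compute pixel accuracy and mean IoU over flat sequences of class indices."""
    correct = labeled = 0
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if t == ignore_label:
            continue                 # ignored pixels count toward neither metric
        labeled += 1
        if p == t:
            correct += 1
            inter[t] += 1
            union[t] += 1
        else:
            union[t] += 1            # missed pixel of the true class
            union[p] += 1            # false positive for the predicted class
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return correct / labeled, sum(ious) / len(ious)


pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, -1]
pix_acc, miou = pixel_accuracy_and_miou(pred, target, num_classes=3)
print(pix_acc, miou)
```

In this toy example four of the five labeled pixels are correct, and the per-class IoUs are 1/2, 2/3 and 1.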

# [Model Description](#Contents)

## Performance

### Training Performance

| Parameter           | ICNet                                                |
| ------------------- | ---------------------------------------------------- |
| Resources           | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G      |
| Upload date         | 2021.6.1                                             |
| MindSpore version   | 1.2.0                                                |
| Training parameters | epoch=160, batch_size=32                             |
| Optimizer           | SGD optimizer, momentum=0.9, weight_decay=0.0001     |
| Loss function       | SoftmaxCrossEntropyLoss                              |
| Training speed      | epoch time: 285693.557 ms; per step time: 42.961 ms  |
| Total time          | about 5 hours                                        |
| Script URL          |                                                      |
| Random number seed  | set_seed = 1234                                      |

# [Description of Random Situation](#Contents)

The seed in the `create_icnet_dataset` function is set in `cityscapes_mindrecord.py`, and the random seed in `train.py` is also used for weight initialization.
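
One detail worth checking when relying on this seed: `cityscapes_mindrecord.py` seeds a separate `random.Random()` instance, while the augmentation draws call the module-level `random` functions. The sketch below uses only the standard library and hypothetical function names to illustrate the difference between the two generators; it does not claim how the training run was actually seeded.

```python
import random


def draw_with_private_generator(seed):
    """Seed a private generator and draw from that same generator."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(3)]


def draw_with_global_generator(seed):
    """Seed the module-level generator and draw from the module-level functions."""
    random.seed(seed)
    return [random.random() for _ in range(3)]


# Both approaches are reproducible on their own.
assert draw_with_private_generator(1234) == draw_with_private_generator(1234)
assert draw_with_global_generator(1234) == draw_with_global_generator(1234)

# Creating and seeding a private instance leaves the module-level generator untouched.
random.seed(0)
expected = random.random()
random.seed(0)
unused = random.Random(1234)     # private instance, seeded but never drawn from
unchanged = random.random() == expected
print(unchanged)
```

For the draws to be reproducible, the seeded generator and the generator used for drawing have to be the same object.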

# [ModelZoo Homepage](#Contents)

Please visit the official website [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluate mIoU and pixAcc"""
import os
import time
import sys
import argparse
import yaml
import numpy as np
from PIL import Image
import mindspore.ops as ops
from mindspore import load_param_into_net
from mindspore import load_checkpoint
from mindspore import Tensor
import mindspore.dataset.vision.py_transforms as transforms

parser = argparse.ArgumentParser(description="ICNet Evaluation")
parser.add_argument("--dataset_path", type=str, default="/data/cityscapes/", help="dataset path")
parser.add_argument("--checkpoint_path", type=str, default="/root/ICNet/ckpt/ICNet-160_93_699.ckpt",
                    help="checkpoint_path, default67.7")
parser.add_argument("--project_path", type=str, default='/root/ICNet/',
                    help="project_path,default is /root/ICNet/")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")

args_opt = parser.parse_args()


class Evaluator:
    """evaluate"""

    def __init__(self, config):
        self.cfg = config

        # get valid dataset images and targets
        self.image_paths, self.mask_paths = _get_city_pairs(config["train"]["cityscapes_root"], "val")

        # create network
        self.model = ICNet(nclass=19, backbone='resnet50', istraining=False)

        # load ckpt
        ckpt_file_name = args_opt.checkpoint_path
        param_dict = load_checkpoint(ckpt_file_name)
        load_param_into_net(self.model, param_dict)

        # evaluation metrics
        self.metric = SegmentationMetric(19)

    def eval(self):
        """evaluate"""
        self.metric.reset()
        model = self.model
        model = model.set_train(False)

        logger.info("Start validation, Total sample: {:d}".format(len(self.image_paths)))
        list_time = []
        # one operator instance is enough for both the image and the mask
        expand_dims = ops.ExpandDims()

        for i in range(len(self.image_paths)):
            image = Image.open(self.image_paths[i]).convert('RGB')  # image size: (W,H), 3 channels
            mask = Image.open(self.mask_paths[i])  # mask size: (W,H)

            image = self._img_transform(image)  # image shape: (3,H,W) [0,1]
            mask = self._mask_transform(mask)  # mask shape: (H,W)

            image = Tensor(image)
            image = expand_dims(image, 0)

            # time only the forward pass
            start_time = time.time()
            output = model(image)
            end_time = time.time()
            step_time = end_time - start_time

            mask = expand_dims(mask, 0)
            self.metric.update(output, mask)
            list_time.append(step_time)

        pixAcc, mIoU = self.metric.get()

        average_time = sum(list_time) / len(list_time)

        print("avgmiou", mIoU)
        print("avg_pixacc", pixAcc)
        print("avgtime", average_time)

    def _img_transform(self, image):
        """img_transform"""
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([.485, .456, .406], [.229, .224, .225])
        image = to_tensor(image)
        image = normalize(image)
        return image

    def _mask_transform(self, mask):
        """map label ids to training classes and return an int32 Tensor"""
        mask = self._class_to_index(np.array(mask).astype('int32'))
        return Tensor(np.array(mask).astype('int32'))

    def _class_to_index(self, mask):
        """assert the value"""
        values = np.unique(mask)
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
        for value in values:
            assert value in self._mapping
        # Get the index of each pixel value in the mask corresponding to _mapping
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        # Use that index to look up the training class in _key
        return self._key[index].reshape(mask.shape)


def _get_city_pairs(folder, split='train'):
    """get dataset img_mask_path_pairs"""

    def get_path_pairs(image_folder, mask_folder):
        img_paths = []
        mask_paths = []
        for root, _, files in os.walk(image_folder):
            for filename in files:
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), image_folder))
        return img_paths, mask_paths

    if split in ('train', 'val', 'test'):
        # "./Cityscapes/leftImg8bit/train" or "./Cityscapes/leftImg8bit/val"
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        # "./Cityscapes/gtFine/train" or "./Cityscapes/gtFine/val"
        mask_folder = os.path.join(folder, 'gtFine/' + split)

        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths


if __name__ == '__main__':
    sys.path.append(args_opt.project_path)
    from src.models import ICNet
    from src.metric import SegmentationMetric
    from src.logger import SetupLogger
    # Set config file
    config_path = os.path.join(args_opt.project_path, "src/model_utils/icnet.yaml")
    with open(config_path, "r") as yaml_file:
        # safe_load avoids the Loader argument that newer PyYAML requires for yaml.load
        cfg = yaml.safe_load(yaml_file)
    logger = SetupLogger(name="semantic_segmentation",
                         save_dir=cfg["train"]["ckpt_dir"],
                         distributed_rank=0,
                         filename='{}_{}_evaluate_log.txt'.format(cfg["model"]["name"], cfg["model"]["backbone"]))

    evaluator = Evaluator(cfg)
    evaluator.eval()
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import os
import yaml
import numpy as np

from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import mindspore.common.dtype as dtype
from src.models import ICNet

parser = argparse.ArgumentParser(description='ICNet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="icnet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
parser.add_argument("--project_path", type=str, default='/root/ICNet/',
                    help="project_path,default is /root/ICNet/")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    config_path = os.path.join(args.project_path, "src/model_utils/icnet.yaml")
    with open(config_path, "r") as yaml_file:
        # safe_load avoids the Loader argument that newer PyYAML requires for yaml.load
        cfg = yaml.safe_load(yaml_file)
    net = ICNet()
    param_dict = load_checkpoint(args.ckpt_file)

    load_param_into_net(net, param_dict)
    net.set_train(False)

    # dummy input of shape [batch, 3, base_size, base_size * 2], e.g. 1024 x 2048 for Cityscapes
    img = Tensor(np.ones([args.batch_size, 3, cfg['model']["base_size"], cfg['model']["base_size"]*2]), dtype.float32)

    export(net, img, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,57 @@
|
|||
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "=============================================================================================================="
    echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PROJECT_PATH]"
    echo "for example: bash scripts/run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /root/ICNet/"
    echo "=============================================================================================================="
    exit 1
fi

# Resolve a possibly relative path to an absolute one (same helper as in run_eval.sh).
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

RANK_TABLE=$(get_real_path "$1")
PROJECT_PATH=$(get_real_path "$2")

if [ ! -f "$RANK_TABLE" ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE is not a file"
    exit 1
fi

if [ ! -d "$PROJECT_PATH" ]
then
    echo "error: PROJECT_PATH=$PROJECT_PATH is not a directory"
    exit 1
fi

export HCCL_CONNECT_TIMEOUT=600
export RANK_SIZE=8

for((i=0;i<$RANK_SIZE;i++))
do
    rm -rf LOG$i
    mkdir ./LOG$i
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cd ./LOG$i || exit
    export RANK_TABLE_FILE=$RANK_TABLE
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log

    python3 train.py --project_path=$PROJECT_PATH > log.txt 2>&1 &

    cd ../
done
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]"
    exit 1
fi

# Resolve a possibly relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")
PATH3=$(get_real_path "$3")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
# The script takes only three arguments, so the device id is fixed here
# (0 matches the default of --device_id in eval.py).
export DEVICE_ID=0
export RANK_SIZE=1
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../eval.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --project_path=$PATH3 &> log &

cd ..
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Prepare Cityscapes dataset"""
import os
import random
import numpy as np
from PIL import Image
from PIL import ImageOps
from PIL import ImageFilter
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.vision.py_transforms as transforms
import mindspore.dataset.transforms.py_transforms as tc


def _get_city_pairs(folder, split='train'):
    """Return two path arrays of data set img and mask"""
    def get_path_pairs(image_folder, masks_folder):
        image_paths = []
        masks_paths = []
        for root, _, files in os.walk(image_folder):
            for filename in files:
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(masks_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        image_paths.append(imgpath)
                        masks_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(image_paths), image_folder))
        return image_paths, masks_paths

    if split in ('train', 'val'):
        # "./Cityscapes/leftImg8bit/train" or "./Cityscapes/leftImg8bit/val"
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        # "./Cityscapes/gtFine/train" or "./Cityscapes/gtFine/val"
        mask_folder = os.path.join(folder, 'gtFine/' + split)

        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths


def _sync_transform(img, mask):
    """img and mask augmentation"""
    # NOTE: this private generator is seeded but not used below;
    # the draws in this function come from the module-level `random` functions.
    a = random.Random()
    a.seed(1234)
    base_size = 1024
    crop_size = 960

    # random mirror
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    # random scale (short edge)
    short_size = random.randint(int(base_size * 0.5), int(base_size * 2.0))
    w, h = img.size
    if h > w:
        ow = short_size
        oh = int(1.0 * h * ow / w)
    else:
        oh = short_size
        ow = int(1.0 * w * oh / h)
    img = img.resize((ow, oh), Image.BILINEAR)
    mask = mask.resize((ow, oh), Image.NEAREST)
    # pad crop
    if short_size < crop_size:
        padh = crop_size - oh if oh < crop_size else 0
        padw = crop_size - ow if ow < crop_size else 0
        img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
    # random crop crop_size
    w, h = img.size
    x1 = random.randint(0, w - crop_size)
    y1 = random.randint(0, h - crop_size)
    img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    # gaussian blur as in PSP
    if random.random() < 0.5:
        img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
    # final transform
    output = _img_mask_transform(img, mask)

    return output


def _class_to_index(mask):
    """class to index"""
    # reference: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
    _key = np.array([-1, -1, -1, -1, -1, -1,
                     -1, -1, 0, 1, -1, -1,
                     2, 3, 4, -1, -1, -1,
                     5, -1, 6, 7, 8, 9,
                     10, 11, 12, 13, 14, 15,
                     -1, -1, 16, 17, 18])
    # [-1, ..., 33]
    _mapping = np.array(range(-1, len(_key) - 1)).astype('int32')

    # assert the value
    values = np.unique(mask)
    for value in values:
        assert value in _mapping
    # Get the index of each pixel value in the mask corresponding to _mapping
    index = np.digitize(mask.ravel(), _mapping, right=True)
    # Use that index to look up the training class in _key
    return _key[index].reshape(mask.shape)


def _img_transform(img):
    """convert a PIL image to a numpy array"""
    return np.array(img)


def _mask_transform(mask):
    """map label ids to training classes and return an int32 array"""
    target = _class_to_index(np.array(mask).astype('int32'))
    return np.array(target).astype('int32')


def _img_mask_transform(img, mask):
    """img and mask transform"""
    input_transform = tc.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    img = _img_transform(img)
    mask = _mask_transform(mask)
    img = input_transform(img)

    img = np.array(img).astype(np.float32)
    mask = np.array(mask).astype(np.float32)

    return (img, mask)


def data_to_mindrecord_img(prefix='cityscapes.mindrecord', file_num=1,
                           root='/data/cityscapes/', split='train'):
    """to mindrecord"""
    mindrecord_dir = '/data/Mindrecord_cityscapes/'
    mindrecord_path = os.path.join(mindrecord_dir, prefix + "2975-2")

    writer = FileWriter(mindrecord_path, file_num)

    img_paths, mask_paths = _get_city_pairs(root, split)

    cityscapes_json = {
        "images": {"type": "int32", "shape": [1024, 2048, 3]},
        "mask": {"type": "int32", "shape": [1024, 2048]},
    }

    writer.add_schema(cityscapes_json, "cityscapes_json")

    images_files_num = len(img_paths)
    for index in range(images_files_num):
        img = Image.open(img_paths[index]).convert('RGB')
        img = np.array(img, dtype=np.int32)

        mask = Image.open(mask_paths[index])
        mask = np.array(mask, dtype=np.int32)

        row = {"images": img, "mask": mask}
        if (index + 1) % 10 == 0:
            print("writing {}/{} into mindrecord".format(index + 1, images_files_num))
        writer.write_raw_data([row])
    writer.commit()


def get_Image_crop_nor(img, mask):
    """convert arrays back to PIL images and apply the synchronized augmentation"""
    image = np.uint8(img)
    mask = np.uint8(mask)
    image = Image.fromarray(image)
    mask = Image.fromarray(mask)

    output = _sync_transform(image, mask)

    return output


def create_icnet_dataset(mindrecord_file, batch_size=16, device_num=1, rank_id=0):
    """create dataset for training"""
    # NOTE: as in _sync_transform, this private generator is seeded but not used afterwards.
    a = random.Random()
    a.seed(1234)
    ds = de.MindDataset(mindrecord_file, columns_list=["images", "mask"],
                        num_shards=device_num, shard_id=rank_id, shuffle=True)
    ds = ds.map(operations=get_Image_crop_nor, input_columns=["images", "mask"], output_columns=["image", "masks"])

    ds = ds.batch(batch_size=batch_size, drop_remainder=False)

    return ds


if __name__ == '__main__':
    data_to_mindrecord_img()
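
Reviewer note on `_class_to_index`: because `_mapping` is the contiguous range -1..33, `np.digitize(..., right=True)` returns `label_id + 1` for every valid label, so the whole step amounts to a table lookup. The following is a minimal pure-Python sketch of that equivalence. The table is copied from the function; the helper name `label_id_to_train_id` is hypothetical.

```python
# Lookup table copied from _class_to_index: position i holds the training
# class for Cityscapes label id i - 1 (ids run from -1 to 33).
KEY = [-1, -1, -1, -1, -1, -1,
       -1, -1, 0, 1, -1, -1,
       2, 3, 4, -1, -1, -1,
       5, -1, 6, 7, 8, 9,
       10, 11, 12, 13, 14, 15,
       -1, -1, 16, 17, 18]


def label_id_to_train_id(label_id):
    """Map one Cityscapes label id to a training class index (-1 means ignored)."""
    if not -1 <= label_id <= 33:
        raise ValueError(f"unexpected label id: {label_id}")
    return KEY[label_id + 1]


for label_id in (0, 7, 8, 26, 33):
    print(label_id, "->", label_id_to_train_id(label_id))
```

The 19 non-negative entries are the classes the network is trained on; every other label id maps to -1, which the loss treats as the ignore label.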
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Logger"""
import logging
import os
import sys

__all__ = ['SetupLogger']


# reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py
def SetupLogger(name, save_dir, distributed_rank, filename="log.txt", mode='w'):
    """setupLogger"""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode)  # 'a+' for add, 'w' for overwrite
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom losses."""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from src.losses import SoftmaxCrossEntropyLoss
|
||||
|
||||
__all__ = ['ICNetLoss']
|
||||
|
||||
|
||||
class ICNetLoss(nn.Cell):
|
||||
"""Cross Entropy Loss for ICNet"""
|
||||
|
||||
def __init__(self, aux_weight=0.4, ignore_index=-1):
|
||||
super(ICNetLoss, self).__init__()
|
||||
self.aux_weight = aux_weight
|
||||
self.ignore_index = ignore_index
|
||||
self.sparse = True
|
||||
self.base_loss = SoftmaxCrossEntropyLoss(num_cls=19, ignore_label=ignore_index)
|
||||
self.resize_bilinear = nn.ResizeBilinear()  # input must be 4D
|
||||
|
||||
def construct(self, *inputs):
|
||||
"""construct"""
|
||||
preds, target = inputs
|
||||
|
||||
pred = preds[0]
|
||||
pred_sub4 = preds[1]
|
||||
pred_sub8 = preds[2]
|
||||
pred_sub16 = preds[3]
|
||||
|
||||
# [batch, H, W] -> [batch, 1, H, W]
|
||||
expand_dims = ops.ExpandDims()
|
||||
if target.shape[0] == 720 or target.shape[0] == 1024:
|
||||
target = expand_dims(target, 0).astype(ms.dtype.float32)
|
||||
target = expand_dims(target, 0).astype(ms.dtype.float32)
|
||||
else:
|
||||
target = expand_dims(target, 1).astype(ms.dtype.float32)
|
||||
|
||||
h, w = pred.shape[2:]
|
||||
|
||||
target_sub4 = self.resize_bilinear(target, size=(h // 4, w // 4)).squeeze(1)
|
||||
|
||||
target_sub8 = self.resize_bilinear(target, size=(h // 8, w // 8)).squeeze(1)
|
||||
|
||||
target_sub16 = self.resize_bilinear(target, size=(h // 16, w // 16)).squeeze(1)
|
||||
|
||||
loss1 = self.base_loss(pred_sub4, target_sub4)
|
||||
loss2 = self.base_loss(pred_sub8, target_sub8)
|
||||
loss3 = self.base_loss(pred_sub16, target_sub16)
|
||||
|
||||
return loss1 + loss2 * self.aux_weight + loss3 * self.aux_weight
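The cascade label guidance above boils down to two pieces of arithmetic: the label map is resized to 1/4, 1/8 and 1/16 of the full-resolution prediction, and the three branch losses are summed with the two coarser ones scaled by `aux_weight`. A minimal plain-Python sketch follows; the helper names `aux_target_sizes` and `combine_losses` are illustrative and not part of this commit, and the 960×960 input is just an example size.

```python
def aux_target_sizes(h, w, strides=(4, 8, 16)):
    """Integer (height, width) of the resized label map for each branch."""
    return [(h // s, w // s) for s in strides]


def combine_losses(loss_sub4, loss_sub8, loss_sub16, aux_weight=0.4):
    """Weighted sum used by the loss above: only the two coarser terms are scaled."""
    return loss_sub4 + loss_sub8 * aux_weight + loss_sub16 * aux_weight


sizes = aux_target_sizes(960, 960)
total = combine_losses(1.0, 0.5, 0.5)
print(sizes)
print(total)
```

Integer division keeps the target sizes usable as tensor shapes; the loss values passed to `combine_losses` here are placeholders, not measured results.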
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""loss unit"""
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLoss(nn.Cell):
|
||||
"""SoftmaxCrossEntropyLoss"""
|
||||
def __init__(self, num_cls=19, ignore_label=-1):
|
||||
super(SoftmaxCrossEntropyLoss, self).__init__()
|
||||
self.one_hot = P.OneHot(axis=-1)
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.num_cls = num_cls
|
||||
self.ignore_label = ignore_label
|
||||
self.mul = P.Mul()
|
||||
self.sum = P.ReduceSum(False)
|
||||
self.div = P.RealDiv()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, labels):
|
||||
"""construct"""
|
||||
labels_int = self.cast(labels, mstype.int32)
|
||||
labels_int = self.reshape(labels_int, (-1,))
|
||||
logits_ = self.transpose(logits, (0, 2, 3, 1))
|
||||
logits_ = self.reshape(logits_, (-1, self.num_cls))
|
||||
weights = self.not_equal(labels_int, self.ignore_label)
|
||||
weights = self.cast(weights, mstype.float32)
|
||||
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
|
||||
loss = self.ce(logits_, one_hot_labels)
|
||||
loss = self.mul(weights, loss)
|
||||
loss = self.div(self.sum(loss), self.sum(weights))
|
||||
return loss
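The masking logic above (weight 0 for pixels labeled `ignore_label`, then a sum of losses divided by the sum of weights) can be checked by hand with a small pure-Python sketch. It assumes the logits are already flattened to one score list per pixel; `masked_softmax_ce` is a hypothetical helper written for illustration, not the MindSpore implementation.

```python
import math


def masked_softmax_ce(pixel_logits, pixel_labels, ignore_label=-1):
    """Mean cross-entropy over pixels whose label is not ignore_label.

    pixel_logits: list of per-pixel score lists, shape [num_pixels][num_cls].
    pixel_labels: list of int class ids, one per pixel.
    """
    total, count = 0.0, 0
    for scores, label in zip(pixel_logits, pixel_labels):
        if label == ignore_label:
            continue  # weight 0: adds to neither the numerator nor the denominator
        peak = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_sum_exp - scores[label]  # -log softmax(scores)[label]
        count += 1
    return total / count


# Uniform scores over 3 classes give a loss of log(3) for every labeled pixel.
logits = [[0.0, 0.0, 0.0]] * 4
labels = [0, 1, -1, 2]
print(masked_softmax_ce(logits, labels))
```

As in the class above, a batch in which every pixel is ignored leads to a division by zero, so callers would need to guard against that case separately.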
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Popular Learning Rate Schedulers"""
|
||||
from __future__ import division
|
||||
|
||||
|
||||
def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
|
||||
for i in range(total_steps):
|
||||
step_ = min(i, decay_steps)
|
||||
yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr
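The scheduler above computes `lr_i = (base_lr - end_lr) * (1 - min(i, decay_steps) / decay_steps) ** power + end_lr`. The sketch below restates the same generator with its indentation and evaluates it for a few steps; `base_lr=0.02` is the `init_lr` from this commit's config, while the step counts are made up to keep the output short.

```python
def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
    """Polynomial decay from base_lr to end_lr, held at end_lr after decay_steps."""
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr


# Illustrative step counts only: decay over 4 steps, then hold for 2 more.
lrs = list(poly_lr(base_lr=0.02, decay_steps=4, total_steps=6))
for step, lr in enumerate(lrs):
    print(step, round(lr, 6))
```

Since the function is a generator, it has to be materialized (for example with `list(...)`) before being handed to an optimizer that expects a sequence of per-step learning rates.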
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation Metrics for Semantic Segmentation"""
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
import mindspore.common.dtype as dtype
|
||||
|
||||
__all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union']
|
||||
|
||||
|
||||
class SegmentationMetric:
|
||||
"""Computes pixAcc and mIoU metric scores"""
|
||||
|
||||
def __init__(self, nclass):
|
||||
super(SegmentationMetric, self).__init__()
|
||||
self.nclass = nclass
|
||||
self.reset()
|
||||
|
||||
def update(self, pred, label):
|
||||
"""Updates the internal evaluation result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : Tensor of shape [N, H, W]
|
||||
The labels of the data.
|
||||
pred : Tensor of shape [N, C, H, W]
|
||||
Predicted values.
|
||||
"""
|
||||
correct, labeled = batch_pix_accuracy(pred, label)
|
||||
inter, union = batch_intersection_union(pred, label, self.nclass)
|
||||
|
||||
self.total_correct += correct
|
||||
self.total_label += labeled
|
||||
|
||||
self.total_inter += inter
|
||||
self.total_union += union
|
||||
|
||||
def get(self):
|
||||
"""Gets the current evaluation result.
|
||||
|
||||
Returns
|
||||
-------
|
||||
metrics : tuple of float
|
||||
pixAcc and mIoU
|
||||
"""
|
||||
mean = ops.ReduceMean(keep_dims=False)
|
||||
pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # the literal equals np.spacing(1); guards against division by zero
|
||||
IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
|
||||
|
||||
mIoU = mean(IoU, axis=0)
|
||||
|
||||
return pixAcc, mIoU
|
||||
|
||||
def reset(self):
|
||||
"""Resets the internal evaluation result to initial state."""
|
||||
zeros = ops.Zeros()
|
||||
self.total_inter = zeros(self.nclass, dtype.float32)
|
||||
self.total_union = zeros(self.nclass, dtype.float32)
|
||||
self.total_correct = 0
|
||||
self.total_label = 0
|
||||
|
||||
|
||||
def batch_pix_accuracy(output, target):
|
||||
"""PixAcc"""
|
||||
|
||||
predict = ops.Argmax(output_type=dtype.int32, axis=1)(output) + 1
|
||||
# (1,19, 1024,2048)-->(1, 1024,2048)
|
||||
target = target + 1
|
||||
|
||||
typetrue = dtype.float32
|
||||
cast = ops.Cast()
|
||||
sumtarget = ops.ReduceSum()
|
||||
sumcorrect = ops.ReduceSum()
|
||||
|
||||
labeled = cast(target > 0, typetrue)
|
||||
pixel_labeled = sumtarget(labeled)  # number of labeled pixels (target > 0 after the +1 shift)
|
||||
|
||||
pixel_correct = sumcorrect(cast(predict == target, typetrue) * cast(target > 0, typetrue))  # number of correctly labeled pixels
|
||||
|
||||
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
|
||||
return pixel_correct, pixel_labeled
|
||||
|
||||
|
||||
def batch_intersection_union(output, target, nclass):
|
||||
"""mIoU"""
|
||||
# inputs are MindSpore Tensors: output is 4D [N,C,H,W], target is 3D [N,H,W]
|
||||
predict = ops.Argmax(output_type=dtype.int32, axis=1)(output) + 1 # [N,H,W]
|
||||
target = target.astype(dtype.float32) + 1 # [N,H,W]
|
||||
|
||||
typetrue = dtype.float32
|
||||
cast = ops.Cast()
|
||||
predict = cast(predict, typetrue) * cast(target > 0, typetrue)
|
||||
intersection = cast(predict, typetrue) * cast(predict == target, typetrue)
|
||||
# areas of intersection and union
|
||||
# Bin 0 collects ignored and mismatched pixels (the main difference from an np.bincount-based version); it is dropped below with [1:].
|
||||
|
||||
Range = Tensor([0.0, float(nclass + 1)], dtype.float32)
|
||||
hist = ops.HistogramFixedWidth(nclass + 1)
|
||||
area_inter = hist(intersection, Range)
|
||||
area_pred = hist(predict, Range)
|
||||
area_lab = hist(target, Range)
|
||||
|
||||
area_union = area_pred + area_lab - area_inter
|
||||
|
||||
area_inter = area_inter[1:]
|
||||
area_union = area_union[1:]
|
||||
Sum = ops.ReduceSum()
|
||||
assert Sum(cast(area_inter > area_union, typetrue)) == 0, "Intersection area should be smaller than Union area"
|
||||
return cast(area_inter, typetrue), cast(area_union, typetrue)
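The histogram-based counting above can be cross-checked on tiny inputs with a pure-Python sketch. It assumes predictions and labels are already flattened to lists of class ids and that ignored pixels carry the label -1; `segmentation_counts` and `scores` are hypothetical helpers for illustration, not functions from this commit.

```python
def segmentation_counts(pred, target, nclass, ignore_label=-1):
    """Per-class intersection and union, plus correct and labeled pixel counts."""
    inter = [0] * nclass
    area_pred = [0] * nclass
    area_lab = [0] * nclass
    for p, t in zip(pred, target):
        if t == ignore_label:
            continue  # ignored pixels count for neither prediction nor label
        area_pred[p] += 1
        area_lab[t] += 1
        if p == t:
            inter[p] += 1
    union = [ap + al - i for ap, al, i in zip(area_pred, area_lab, inter)]
    return inter, union, sum(inter), sum(area_lab)


def scores(inter, union, correct, labeled, eps=2.220446049250313e-16):
    """pixAcc and mIoU with the same epsilon guard as the metric class above."""
    pix_acc = correct / (eps + labeled)
    miou = sum(i / (eps + u) for i, u in zip(inter, union)) / len(inter)
    return pix_acc, miou


pred = [0, 1, 1, 2]
target = [0, 1, 2, -1]
inter, union, correct, labeled = segmentation_counts(pred, target, nclass=3)
pix_acc, miou = scores(inter, union, correct, labeled)
print(inter, union, pix_acc, miou)
```

A class that appears in neither the predictions nor the labels has a union of 0 and therefore contributes an IoU of 0 to the mean, which pulls mIoU down on small evaluation sets.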
|
|
@ -0,0 +1,28 @@
|
|||
### 1.Model
|
||||
model:
|
||||
name: "icnet"
|
||||
backbone: "resnet50"
|
||||
base_size: 1024 # during augmentation, the shorter side is resized to a value in [base_size*0.5, base_size*2.0]
|
||||
crop_size: 960 # at the end of augmentation, crop to this size for training
|
||||
|
||||
### 2.Optimizer
|
||||
optimizer:
|
||||
init_lr: 0.02
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001
|
||||
|
||||
### 3.Training
|
||||
train:
|
||||
train_batch_size_percard: 4
|
||||
valid_batch_size: 1
|
||||
cityscapes_root: "/data/cityscapes/"
|
||||
epochs: 160
|
||||
val_epoch: 1 # run validation every val_epoch epochs
|
||||
ckpt_dir: "./ckpt/" # ckpt and training log will be saved here
|
||||
mindrecord_dir: '/root/mindrecord'
|
||||
save_checkpoint_epochs: 5
|
||||
keep_checkpoint_max: 10
|
||||
|
||||
### 4.Valid
|
||||
test:
|
||||
ckpt_path: "" # set the pretrained model path correctly
|
|
@ -0,0 +1,3 @@
|
|||
"""__init__"""
|
||||
from .icnet import ICNet
|
||||
from .icnet_dc import ICNetdc
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Image Cascade Network"""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import context
|
||||
from src.loss import ICNetLoss
|
||||
from src.models.resnet50_v1 import get_resnet50v1b
|
||||
|
||||
__all__ = ['ICNet']
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
|
||||
|
||||
class ICNet(nn.Cell):
|
||||
"""Image Cascade Network"""
|
||||
|
||||
def __init__(self, nclass=19, backbone='resnet50', pretrained_base=True, istraining=True):
|
||||
super(ICNet, self).__init__()
|
||||
self.conv_sub1 = nn.SequentialCell(
|
||||
_ConvBNReLU(3, 32, 3, 2),
|
||||
_ConvBNReLU(32, 32, 3, 2),
|
||||
_ConvBNReLU(32, 64, 3, 2)
|
||||
)
|
||||
self.istraining = istraining
|
||||
self.ppm = PyramidPoolingModule()
|
||||
|
||||
self.backbone = SegBaseModel()
|
||||
|
||||
self.head = _ICHead(nclass)
|
||||
|
||||
self.loss = ICNetLoss()
|
||||
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
self.__setattr__('exclusive', ['conv_sub1', 'head'])
|
||||
|
||||
def construct(self, x):
|
||||
"""ICNet_construct"""
|
||||
if x.shape[0] != 1:
|
||||
x = x.squeeze()
|
||||
# sub 1
|
||||
x_sub1 = self.conv_sub1(x)
|
||||
|
||||
h, w = x.shape[2:]
|
||||
# sub 2
|
||||
x_sub2 = self.resize_bilinear(x, size=(h // 2, w // 2))
|
||||
_, x_sub2, _, _ = self.backbone(x_sub2)
|
||||
|
||||
# sub 4
|
||||
_, _, _, x_sub4 = self.backbone(x)
|
||||
# add PyramidPoolingModule
|
||||
x_sub4 = self.ppm(x_sub4)
|
||||
|
||||
output = self.head(x_sub1, x_sub2, x_sub4)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class PyramidPoolingModule(nn.Cell):
|
||||
"""PPM"""
|
||||
|
||||
def __init__(self, pyramids=None):
|
||||
super(PyramidPoolingModule, self).__init__()
|
||||
self.avgpool = ops.ReduceMean(keep_dims=True)
|
||||
self.pool2 = nn.AvgPool2d(kernel_size=15, stride=15)
|
||||
self.pool3 = nn.AvgPool2d(kernel_size=10, stride=10)
|
||||
self.pool6 = nn.AvgPool2d(kernel_size=5, stride=5)
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
def construct(self, x):
|
||||
"""ppm_construct"""
|
||||
feat = x
|
||||
height, width = x.shape[2:]
|
||||
|
||||
x1 = self.avgpool(x, (2, 3))
|
||||
x1 = self.resize_bilinear(x1, size=(height, width), align_corners=True)
|
||||
feat = feat + x1
|
||||
|
||||
x2 = self.pool2(x)
|
||||
x2 = self.resize_bilinear(x2, size=(height, width), align_corners=True)
|
||||
feat = feat + x2
|
||||
|
||||
x3 = self.pool3(x)
|
||||
x3 = self.resize_bilinear(x3, size=(height, width), align_corners=True)
|
||||
feat = feat + x3
|
||||
|
||||
x6 = self.pool6(x)
|
||||
x6 = self.resize_bilinear(x6, size=(height, width), align_corners=True)
|
||||
feat = feat + x6
|
||||
|
||||
return feat
|
||||
|
||||
|
||||
class _ICHead(nn.Cell):
|
||||
"""Head"""
|
||||
|
||||
def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
|
||||
super(_ICHead, self).__init__()
|
||||
self.cff_12 = CascadeFeatureFusion12(128, 64, 128, nclass, norm_layer, **kwargs)
|
||||
self.cff_24 = CascadeFeatureFusion24(2048, 512, 128, nclass, norm_layer, **kwargs)
|
||||
|
||||
self.conv_cls = nn.Conv2d(128, nclass, 1, has_bias=False)
|
||||
self.outputs = list()
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
def construct(self, x_sub1, x_sub2, x_sub4):
|
||||
"""Head_construct"""
|
||||
outputs = self.outputs
|
||||
x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
|
||||
|
||||
# x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1)
|
||||
x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
|
||||
|
||||
h1, w1 = x_cff_12.shape[2:]
|
||||
up_x2 = self.resize_bilinear(x_cff_12, size=(h1 * 2, w1 * 2),
|
||||
align_corners=True)
|
||||
up_x2 = self.conv_cls(up_x2)
|
||||
h2, w2 = up_x2.shape[2:]
|
||||
|
||||
up_x8 = self.resize_bilinear(up_x2, size=(h2 * 4, w2 * 4),
|
||||
align_corners=True)  # upsample by a factor of 4
|
||||
outputs.append(up_x8)
|
||||
outputs.append(up_x2)
|
||||
outputs.append(x_12_cls)
|
||||
outputs.append(x_24_cls)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class _ConvBNReLU(nn.Cell):
|
||||
"""ConvBNRelu"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, dilation=1,
|
||||
groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs):
|
||||
super(_ConvBNReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
dilation=dilation,
|
||||
group=groups, has_bias=bias)
|
||||
self.bn = norm_layer(out_channels, momentum=0.1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class CascadeFeatureFusion12(nn.Cell):
|
||||
"""CFF Unit"""
|
||||
|
||||
def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
|
||||
super(CascadeFeatureFusion12, self).__init__()
|
||||
self.conv_low = nn.SequentialCell(
|
||||
nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_high = nn.SequentialCell(
|
||||
nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
self.scalar_cast = ops.ScalarCast()
|
||||
|
||||
self.relu = ms.nn.ReLU()
|
||||
|
||||
def construct(self, x_low, x_high):
|
||||
"""cff_construct"""
|
||||
h, w = x_high.shape[2:]
|
||||
x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
|
||||
x_low = self.conv_low(x_low)
|
||||
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
|
||||
x = self.relu(x)
|
||||
x_low_cls = self.conv_low_cls(x_low)
|
||||
|
||||
return x, x_low_cls
|
||||
|
||||
|
||||
class CascadeFeatureFusion24(nn.Cell):
|
||||
"""CFF Unit"""
|
||||
|
||||
def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
|
||||
super(CascadeFeatureFusion24, self).__init__()
|
||||
self.conv_low = nn.SequentialCell(
|
||||
nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_high = nn.SequentialCell(
|
||||
nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
|
||||
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
self.relu = ms.nn.ReLU()
|
||||
|
||||
def construct(self, x_low, x_high):
|
||||
"""cff_construct"""
|
||||
h, w = x_high.shape[2:]
|
||||
|
||||
x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
|
||||
x_low = self.conv_low(x_low)
|
||||
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
|
||||
x = self.relu(x)
|
||||
x_low_cls = self.conv_low_cls(x_low)
|
||||
|
||||
return x, x_low_cls
|
||||
|
||||
|
||||
class SegBaseModel(nn.Cell):
|
||||
"""Base Model for Semantic Segmentation"""
|
||||
|
||||
def __init__(self, nclass=19, backbone='resnet50', pretrained_base=True, **kwargs):
|
||||
super(SegBaseModel, self).__init__()
|
||||
self.nclass = nclass
|
||||
if backbone == 'resnet50':
|
||||
self.pretrained = get_resnet50v1b()
|
||||
|
||||
def construct(self, x):
|
||||
"""forwarding pre-trained network"""
|
||||
x = self.pretrained.conv1(x)
|
||||
x = self.pretrained.bn1(x)
|
||||
x = self.pretrained.relu(x)
|
||||
x = self.pretrained.maxpool(x)
|
||||
c1 = self.pretrained.layer1(x)
|
||||
c2 = self.pretrained.layer2(c1)
|
||||
c3 = self.pretrained.layer3(c2)
|
||||
c4 = self.pretrained.layer4(c3)
|
||||
|
||||
return c1, c2, c3, c4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
|
@ -0,0 +1,259 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Image Cascade Network"""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import context
|
||||
from src.loss import ICNetLoss
|
||||
from src.models.resnet50_v1 import get_resnet50v1b
|
||||
|
||||
__all__ = ['ICNetdc']
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
|
||||
|
||||
class ICNetdc(nn.Cell):
|
||||
"""Image Cascade Network"""
|
||||
|
||||
def __init__(self, nclass=19, backbone='resnet50', pretrained_base=True, istraining=True):
|
||||
super(ICNetdc, self).__init__()
|
||||
self.conv_sub1 = nn.SequentialCell(
|
||||
_ConvBNReLU(3, 32, 3, 2),
|
||||
_ConvBNReLU(32, 32, 3, 2),
|
||||
_ConvBNReLU(32, 64, 3, 2)
|
||||
)
|
||||
self.istraining = istraining
|
||||
self.ppm = PyramidPoolingModule()
|
||||
|
||||
self.backbone = SegBaseModel()
|
||||
|
||||
self.head = _ICHead(nclass)
|
||||
|
||||
self.loss = ICNetLoss()
|
||||
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
self.__setattr__('exclusive', ['conv_sub1', 'head'])
|
||||
|
||||
def construct(self, x, y):
|
||||
"""ICNet_construct"""
|
||||
if x.shape[0] != 1:
|
||||
x = x.squeeze()
|
||||
# sub 1
|
||||
x_sub1 = self.conv_sub1(x)
|
||||
|
||||
h, w = x.shape[2:]
|
||||
# sub 2
|
||||
x_sub2 = self.resize_bilinear(x, size=(h // 2, w // 2))
|
||||
_, x_sub2, _, _ = self.backbone(x_sub2)
|
||||
|
||||
# sub 4
|
||||
_, _, _, x_sub4 = self.backbone(x)
|
||||
# add PyramidPoolingModule
|
||||
x_sub4 = self.ppm(x_sub4)
|
||||
|
||||
output = self.head(x_sub1, x_sub2, x_sub4)
|
||||
|
||||
if self.istraining:
|
||||
outputs = self.loss(output, y)
|
||||
else:
|
||||
outputs = output
|
||||
return outputs
|
||||
|
||||
|
||||
class PyramidPoolingModule(nn.Cell):
|
||||
"""PPM"""
|
||||
|
||||
def __init__(self, pyramids=None):
|
||||
super(PyramidPoolingModule, self).__init__()
|
||||
self.avgpool = ops.ReduceMean(keep_dims=True)
|
||||
self.pool2 = nn.AvgPool2d(kernel_size=15, stride=15)
|
||||
self.pool3 = nn.AvgPool2d(kernel_size=10, stride=10)
|
||||
self.pool6 = nn.AvgPool2d(kernel_size=5, stride=5)
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
def construct(self, x):
|
||||
"""ppm_construct"""
|
||||
feat = x
|
||||
height, width = x.shape[2:]
|
||||
|
||||
x1 = self.avgpool(x, (2, 3))
|
||||
x1 = self.resize_bilinear(x1, size=(height, width), align_corners=True)
|
||||
feat = feat + x1
|
||||
|
||||
x2 = self.pool2(x)
|
||||
x2 = self.resize_bilinear(x2, size=(height, width), align_corners=True)
|
||||
feat = feat + x2
|
||||
|
||||
x3 = self.pool3(x)
|
||||
x3 = self.resize_bilinear(x3, size=(height, width), align_corners=True)
|
||||
feat = feat + x3
|
||||
|
||||
x6 = self.pool6(x)
|
||||
x6 = self.resize_bilinear(x6, size=(height, width), align_corners=True)
|
||||
feat = feat + x6
|
||||
|
||||
return feat
|
||||
|
||||
|
||||
class _ICHead(nn.Cell):
|
||||
"""Head"""
|
||||
|
||||
def __init__(self, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
|
||||
super(_ICHead, self).__init__()
|
||||
self.cff_12 = CascadeFeatureFusion12(128, 64, 128, nclass, norm_layer, **kwargs)
|
||||
self.cff_24 = CascadeFeatureFusion24(2048, 512, 128, nclass, norm_layer, **kwargs)
|
||||
|
||||
self.conv_cls = nn.Conv2d(128, nclass, 1, has_bias=False)
|
||||
self.outputs = list()
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
def construct(self, x_sub1, x_sub2, x_sub4):
|
||||
"""Head_construct"""
|
||||
outputs = self.outputs
|
||||
x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
|
||||
|
||||
x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
|
||||
|
||||
h1, w1 = x_cff_12.shape[2:]
|
||||
up_x2 = self.resize_bilinear(x_cff_12, size=(h1 * 2, w1 * 2),
|
||||
align_corners=True)
|
||||
up_x2 = self.conv_cls(up_x2)
|
||||
h2, w2 = up_x2.shape[2:]
|
||||
|
||||
up_x8 = self.resize_bilinear(up_x2, size=(h2 * 4, w2 * 4),
|
||||
align_corners=True)  # upsample by a factor of 4
|
||||
outputs.append(up_x8)
|
||||
outputs.append(up_x2)
|
||||
outputs.append(x_12_cls)
|
||||
outputs.append(x_24_cls)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class _ConvBNReLU(nn.Cell):
|
||||
"""ConvBNRelu"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, dilation=1,
|
||||
groups=1, norm_layer=nn.SyncBatchNorm, bias=False, **kwargs):
|
||||
super(_ConvBNReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
dilation=dilation,
|
||||
group=groups, has_bias=bias)
|
||||
self.bn = norm_layer(out_channels, momentum=0.1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class CascadeFeatureFusion12(nn.Cell):
|
||||
"""CFF Unit"""
|
||||
|
||||
def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
|
||||
super(CascadeFeatureFusion12, self).__init__()
|
||||
self.conv_low = nn.SequentialCell(
|
||||
nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_high = nn.SequentialCell(
|
||||
nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
|
||||
self.scalar_cast = ops.ScalarCast()
|
||||
|
||||
self.relu = ms.nn.ReLU()
|
||||
|
||||
def construct(self, x_low, x_high):
|
||||
"""cff_construct"""
|
||||
h, w = x_high.shape[2:]
|
||||
x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
|
||||
x_low = self.conv_low(x_low)
|
||||
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
|
||||
x = self.relu(x)
|
||||
x_low_cls = self.conv_low_cls(x_low)
|
||||
|
||||
return x, x_low_cls
|
||||
|
||||
|
||||
class CascadeFeatureFusion24(nn.Cell):
|
||||
"""CFF Unit"""
|
||||
|
||||
def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
|
||||
super(CascadeFeatureFusion24, self).__init__()
|
||||
self.conv_low = nn.SequentialCell(
|
||||
nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_high = nn.SequentialCell(
|
||||
nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
|
||||
norm_layer(out_channels, momentum=0.1)
|
||||
)
|
||||
self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
|
||||
|
||||
self.resize_bilinear = nn.ResizeBilinear()
|
||||
self.relu = ms.nn.ReLU()
|
||||
|
||||
def construct(self, x_low, x_high):
|
||||
"""cff_construct"""
|
||||
h, w = x_high.shape[2:]
|
||||
|
||||
x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
|
||||
x_low = self.conv_low(x_low)
|
||||
|
||||
x_high = self.conv_high(x_high)
|
||||
x = x_low + x_high
|
||||
|
||||
x = self.relu(x)
|
||||
x_low_cls = self.conv_low_cls(x_low)
|
||||
|
||||
return x, x_low_cls
|
||||
|
||||
|
||||
class SegBaseModel(nn.Cell):
|
||||
"""Base Model for Semantic Segmentation"""
|
||||
|
||||
def __init__(self, nclass=19, backbone='resnet50', pretrained_base=True, **kwargs):
|
||||
super(SegBaseModel, self).__init__()
|
||||
self.nclass = nclass
|
||||
if backbone == 'resnet50':
|
||||
self.pretrained = get_resnet50v1b()
|
||||
|
||||
def construct(self, x):
|
||||
"""forwarding pre-trained network"""
|
||||
x = self.pretrained.conv1(x)
|
||||
x = self.pretrained.bn1(x)
|
||||
x = self.pretrained.relu(x)
|
||||
x = self.pretrained.maxpool(x)
|
||||
c1 = self.pretrained.layer1(x)
|
||||
c2 = self.pretrained.layer2(c1)
|
||||
c3 = self.pretrained.layer3(c2)
|
||||
c4 = self.pretrained.layer4(c3)
|
||||
|
||||
return c1, c2, c3, c4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
|
@ -0,0 +1,288 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""pretrained model resnet50"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore as ms
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
"""calculate_gain"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
res = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
"""_calculate_fan_in_and_fan_out"""
|
||||
dimensions = len(tensor)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor[1]
|
||||
fan_out = tensor[0]
|
||||
else:
|
||||
num_input_fmaps = tensor[1]
|
||||
num_output_fmaps = tensor[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = tensor[2] * tensor[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
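The factor `math.sqrt(3.0)` in `kaiming_uniform` follows from the variance of a uniform distribution: U(-b, b) has variance b^2 / 3, so b = sqrt(3) * std yields the requested standard deviation. The sketch below checks this empirically with the standard library only. `kaiming_uniform_bound` is a hypothetical helper, and the sample count and seed are arbitrary choices.

```python
import math
import random


def kaiming_uniform_bound(fan, gain):
    """Half-width b of U(-b, b) whose standard deviation equals gain / sqrt(fan)."""
    std = gain / math.sqrt(fan)
    return math.sqrt(3.0) * std


# fan_in of a (1001, 2048) dense weight, with the leaky_relu gain at slope sqrt(5)
gain = math.sqrt(2.0 / (1 + 5))
bound = kaiming_uniform_bound(2048, gain)

rng = random.Random(0)
samples = [rng.uniform(-bound, bound) for _ in range(200000)]
# the distribution is centred on zero, so the root mean square estimates the std
empirical_std = math.sqrt(sum(s * s for s in samples) / len(samples))
print(bound, empirical_std, gain / math.sqrt(2048))
```

The printed empirical value should land close to `gain / math.sqrt(2048)`, and the bound itself simplifies to 1 / sqrt(2048) for this choice of slope.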


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))

    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))

    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class BottleneckV1b(nn.Cell):
    """BottleneckV1b"""
    def __init__(self, in_channel, out_channel, stride, dilation=1):
        super().__init__()
        expansion = 4

        # middle channel num
        channel = out_channel // expansion
        self.conv1 = nn.Conv2dBnAct(in_channel, channel, kernel_size=1, stride=1,
                                    has_bn=True, pad_mode="same", activation='relu')

        self.conv2 = nn.Conv2dBnAct(channel, channel, kernel_size=3, stride=stride,
                                    dilation=dilation, has_bn=True, pad_mode="same", activation='relu')

        self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same',
                                    has_bn=True)

        # whether down-sample identity
        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True

        self.down_layer = None
        if self.down_sample:
            self.down_layer = nn.Conv2dBnAct(in_channel, out_channel,
                                             kernel_size=1, stride=stride,
                                             pad_mode='same', has_bn=True)
        self.relu = nn.ReLU()
        self.add = ms.ops.Add()

    def construct(self, x):
        """construct"""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.down_sample:
            identity = self.down_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class Resnet50v1b(nn.Cell):
    """Resnet50v1b"""

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(Resnet50v1b, self).__init__()

        # initial stage
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block=block,
                                       layer_num=layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       dilation=1)
        self.layer2 = self._make_layer(block=block,
                                       layer_num=layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       dilation=1)
        self.layer3 = self._make_layer(block=block,
                                       layer_num=layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       dilation=2)
        self.layer4 = self._make_layer(block=block,
                                       layer_num=layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       dilation=4)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, dilation):
        """make layers"""
        layers = []

        resblock = block(in_channel=in_channel,
                         out_channel=out_channel,
                         stride=stride,
                         dilation=dilation)
        layers.append(resblock)
        for _ in range(1, layer_num):
            resblock = block(in_channel=out_channel,
                             out_channel=out_channel,
                             stride=1,
                             dilation=dilation)
            layers.append(resblock)
        return nn.SequentialCell(layers)

    def construct(self, x):
        """initial stage"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        # four groups
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out


def get_resnet50v1b(class_num=1001, ckpt_root='/root/ICNet/ckpt/ResNet50V1B-150_625.ckpt', pretrained=True):
    """
    Get ResNet50-v1b neural network.

    Args:
        class_num (int): Class number.
        ckpt_root (str): Path of the pretrained checkpoint.
        pretrained (bool): Whether to load the checkpoint at `ckpt_root`.

    Returns:
        Cell, cell instance of ResNet50-v1b neural network.

    Examples:
        >>> net = get_resnet50v1b(1001)
    """

    model = Resnet50v1b(block=BottleneckV1b,
                        layer_nums=[3, 4, 6, 3],
                        in_channels=[64, 256, 512, 1024],
                        out_channels=[256, 512, 1024, 2048],
                        strides=[1, 2, 2, 2],
                        num_classes=class_num)

    if pretrained:
        pretrained_ckpt = ckpt_root
        param_dict = load_checkpoint(pretrained_ckpt)
        load_param_into_net(model, param_dict)
        print("pretrained....")

    return model
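To make the stage configuration concrete, the sketch below expands the lists passed to `Resnet50v1b` into one `(in_channel, out_channel, stride, dilation)` tuple per bottleneck block, mirroring what `_make_layer` does. It runs without MindSpore. `expand_stages` is a hypothetical helper, and the dilation list restates the values hard-coded in `Resnet50v1b.__init__`.

```python
def expand_stages(layer_nums, in_channels, out_channels, strides, dilations):
    """List (in_channel, out_channel, stride, dilation) for every block, stage by stage."""
    blocks = []
    for n, c_in, c_out, stride, dilation in zip(layer_nums, in_channels, out_channels,
                                                strides, dilations):
        # the first block of a stage changes the width and may downsample
        blocks.append((c_in, c_out, stride, dilation))
        # the remaining n - 1 blocks keep width and resolution
        blocks.extend([(c_out, c_out, 1, dilation)] * (n - 1))
    return blocks


blocks = expand_stages(layer_nums=[3, 4, 6, 3],
                       in_channels=[64, 256, 512, 1024],
                       out_channels=[256, 512, 1024, 2048],
                       strides=[1, 2, 2, 2],
                       dilations=[1, 1, 2, 4])
print(len(blocks))
for cfg in blocks[:4]:
    print(cfg)
```

Only the first block of each stage has `in_channel != out_channel` or `stride != 1`, so only those blocks create the 1x1 `down_layer` on the identity path in `BottleneckV1b`.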
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""visualize segmentation"""
import os
import sys
import numpy as np
from PIL import Image
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import load_param_into_net
from mindspore import load_checkpoint
import mindspore.dataset.vision.py_transforms as transforms
from src.models.icnet import ICNet

__all__ = ['get_color_palette', 'set_img_color',
           'show_prediction', 'show_colorful_images', 'save_colorful_images']


def _img_transform(img):
    """img_transform"""
    totensor = transforms.ToTensor()
    normalize = transforms.Normalize([.485, .456, .406], [.229, .224, .225])
    img = totensor(img)
    img = normalize(img)
    return img


def set_img_color(img, label, colors, background=0, show255=False):
    """Paint every non-background class of `label` onto `img` in place."""
    # enumerate yields (index, color) pairs; the index is the class id
    for i, color in enumerate(colors):
        if i != background:
            img[np.where(label == i)] = color
    if show255:
        img[np.where(label == 255)] = 255

    return img


def show_prediction(img, pre, colors, background=0):
    im = np.array(img, np.uint8)
    set_img_color(im, pre, colors, background)
    out = np.array(im)

    return out


def show_colorful_images(prediction, palettes):
    im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
    im.show()


def save_colorful_images(prediction, filename, output_dir, palettes):
    """param prediction: [B, H, W, C]"""
    im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
    fn = os.path.join(output_dir, filename)
    out_dir = os.path.split(fn)[0]
    if not os.path.exists(out_dir):
        # makedirs also creates missing parent directories, which mkdir does not
        os.makedirs(out_dir)
    im.save(fn)


def get_color_palette(npimg, dataset='pascal_voc'):
    """Visualize image.

    Parameters
    ----------
    npimg : numpy.ndarray
        Single channel image with shape `H, W, 1`.
    dataset : str, default: 'pascal_voc'
        The dataset that model pretrained on. ('pascal_voc', 'ade20k')
        Note: the VOC palette is applied for every value of `dataset`;
        `cityspalette` is defined in this module but is not used here.
    Returns
    -------
    out_img : PIL.Image
        Image with color palette
    """
    # recovery boundary
    if dataset in ('pascal_voc', 'pascal_aug'):
        npimg[npimg == -1] = 255
    # put colormap
    out_img = Image.fromarray(npimg.astype('uint8'))
    out_img.putpalette(vocpalette)
    return out_img


def _getvocpalette(num_cls):
    """get_vocpalette"""
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab > 0:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i = i + 1
            lab >>= 3
    return palette
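The bit manipulation in `_getvocpalette` is easier to follow with a few concrete entries. The sketch below repeats the same loop in a standalone function (`voc_palette` is a hypothetical name) and prints the RGB triple for a handful of class ids.

```python
def voc_palette(num_cls):
    """Flat [r0, g0, b0, r1, g1, b1, ...] palette built by bit-spreading the class id."""
    palette = [0] * (num_cls * 3)
    for j in range(num_cls):
        lab, i = j, 0
        while lab > 0:
            # bits 0, 1, 2 of `lab` feed R, G, B, filling from the most significant bit down
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


palette = voc_palette(256)
for cls in (0, 1, 2, 3, 8, 15):
    print(cls, tuple(palette[cls * 3: cls * 3 + 3]))
```

Each group of three bits in the class id contributes one bit to each channel, starting at the top bit. Neighbouring ids therefore get visually distinct colours: id 1 maps to (128, 0, 0) and id 2 to (0, 128, 0).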


vocpalette = _getvocpalette(256)

cityspalette = [
    128, 64, 128,
    244, 35, 232,
    70, 70, 70,
    102, 102, 156,
    190, 153, 153,
    153, 153, 153,
    250, 170, 30,
    220, 220, 0,
    107, 142, 35,
    152, 251, 152,
    0, 130, 180,
    220, 20, 60,
    255, 0, 0,
    0, 0, 142,
    0, 0, 70,
    0, 60, 100,
    0, 80, 100,
    0, 0, 230,
    119, 11, 32,
]


def _class_to_index(mask):
    """assert the value"""
    values = np.unique(mask)
    _key = np.array([-1, -1, -1, -1, -1, -1,
                     -1, -1, 0, 1, -1, -1,
                     2, 3, 4, -1, -1, -1,
                     5, -1, 6, 7, 8, 9,
                     10, 11, 12, 13, 14, 15,
                     -1, -1, 16, 17, 18])
    _mapping = np.array(range(-1, len(_key) - 1)).astype('int32')
    for value in values:
        assert value in _mapping

    index = np.digitize(mask.ravel(), _mapping, right=True)

    return _key[index].reshape(mask.shape)
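`_class_to_index` is a table lookup: `_mapping` holds the sorted raw label ids from -1 to 33, and `np.digitize(..., right=True)` finds each pixel's position in that list, which then indexes `_key`. The sketch below reproduces the lookup for single values using only the standard library. `label_to_train_id` is a hypothetical name, and `bisect_left` stands in for `np.digitize` on the assumption that the bins are sorted in increasing order, as they are here.

```python
from bisect import bisect_left

# Lookup table copied from _class_to_index: position p holds the train id for label id p - 1
KEY = [-1, -1, -1, -1, -1, -1,
       -1, -1, 0, 1, -1, -1,
       2, 3, 4, -1, -1, -1,
       5, -1, 6, 7, 8, 9,
       10, 11, 12, 13, 14, 15,
       -1, -1, 16, 17, 18]
MAPPING = list(range(-1, len(KEY) - 1))   # sorted label ids: -1, 0, ..., 33


def label_to_train_id(label_id):
    """Map one raw label id to its train id, or -1 for ignored classes."""
    if label_id not in MAPPING:
        raise ValueError("unexpected label id {}".format(label_id))
    # on sorted bins, bisect_left returns the same index as np.digitize(..., right=True)
    return KEY[bisect_left(MAPPING, label_id)]


print([label_to_train_id(v) for v in (7, 8, 11, 26, 33, 0)])
```

Because the label ids are consecutive integers starting at -1, the lookup amounts to `KEY[label_id + 1]`. The 19 non-negative entries of `KEY` correspond to the `nclass=19` passed to `ICNet` in the `__main__` block.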


def _mask_transform(mask):
    mask = _class_to_index(np.array(mask).astype('int32'))
    return np.array(mask).astype('int32')


if __name__ == '__main__':
    sys.path.append('/root/ICNet/src/')
    model = ICNet(nclass=19, backbone='resnet50', istraining=False)
    ckpt_file_name = '/root/ICNet/ckpt/ICNet-160_93_699.ckpt'
    param_dict = load_checkpoint(ckpt_file_name)
    load_param_into_net(model, param_dict)
    image_path = 'Test/val_lindau_000023_000019_leftImg8bit.png'
    image = Image.open(image_path).convert('RGB')
    image = _img_transform(image)
    image = Tensor(image)

    expand_dims = ops.ExpandDims()
    image = expand_dims(image, 0)

    squeeze = ops.Squeeze()
    outputs = model(image)
    pred = ops.Argmax(axis=1)(outputs[0])
    pred = pred.asnumpy()
    pred = pred.squeeze(0)
    pred = get_color_palette(pred, "citys")
    pred.save('Test/visual_pred.png')
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ICNet and get checkpoint files."""
import os
import sys
import logging
import argparse
import yaml
import mindspore.nn as nn
from mindspore import Model
from mindspore import context
from mindspore import set_seed
from mindspore.context import ParallelMode
from mindspore.communication import init
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import TimeMonitor

device_id = int(os.getenv('RANK_ID'))
device_num = int(os.getenv('RANK_SIZE'))

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')

parser = argparse.ArgumentParser(description="ICNet Training")
parser.add_argument("--project_path", type=str, help="project_path")

args_opt = parser.parse_args()

def train_net():
    """train"""
    set_seed(1234)
    if device_num > 1:
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True,
                                          gradients_mean=True)
        init()
    prefix = 'cityscapes.mindrecord'
    mindrecord_dir = cfg['train']["mindrecord_dir"]
    mindrecord_file = os.path.join(mindrecord_dir, prefix + '2975')
    dataset = create_icnet_dataset(mindrecord_file, batch_size=cfg['train']["train_batch_size_percard"],
                                   device_num=device_num, rank_id=device_id)

    train_data_size = dataset.get_dataset_size()
    print("data_size", train_data_size)
    epoch = cfg["train"]["epochs"]
    network = ICNetdc()  # __init__

    iters_per_epoch = train_data_size
    total_train_steps = iters_per_epoch * epoch
    base_lr = cfg["optimizer"]["init_lr"]
    iter_lr = poly_lr(base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
    optim = nn.SGD(params=network.trainable_params(), learning_rate=iter_lr, momentum=cfg["optimizer"]["momentum"],
                   weight_decay=cfg["optimizer"]["weight_decay"])

    model = Model(network, optimizer=optim, metrics=None)

    config_ck_train = CheckpointConfig(save_checkpoint_steps=iters_per_epoch * cfg["train"]["save_checkpoint_epochs"],
                                       keep_checkpoint_max=cfg["train"]["keep_checkpoint_max"])
    ckpoint_cb_train = ModelCheckpoint(prefix='ICNet', directory=args_opt.project_path + 'ckpt' + str(device_id),
                                       config=config_ck_train)
    time_cb_train = TimeMonitor(data_size=dataset.get_dataset_size())
    loss_cb_train = LossMonitor()
    print("train begins------------------------------")
    model.train(epoch=epoch, train_dataset=dataset, callbacks=[ckpoint_cb_train, loss_cb_train, time_cb_train],
                dataset_sink_mode=True)
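`train_net` builds a per-step learning-rate list with `poly_lr`, whose body lives in `src/lr_scheduler.py` and is not part of this listing. The sketch below shows the usual polynomial decay formula, lr = (base_lr - end_lr) * (1 - step / decay_steps)^power + end_lr. It is an assumption about what that helper computes, not a copy of it; `poly_lr_sketch` and the example value 0.05 are illustrative only.

```python
def poly_lr_sketch(base_lr, decay_steps, total_steps, end_lr=0.0, power=0.9):
    """Hypothetical stand-in for src.lr_scheduler.poly_lr: one learning rate per step."""
    lrs = []
    for step in range(total_steps):
        s = min(step, decay_steps)         # hold at end_lr once decay_steps is reached
        lrs.append((base_lr - end_lr) * (1 - s / decay_steps) ** power + end_lr)
    return lrs


schedule = poly_lr_sketch(base_lr=0.05, decay_steps=1000, total_steps=1000)
print(schedule[0], schedule[500], schedule[-1])
```

With `decay_steps` equal to `total_steps`, as in the call above, the rate decays smoothly from `base_lr` toward `end_lr` over the whole run, with one entry per training step.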


if __name__ == '__main__':
    # Set config file
    sys.path.append(args_opt.project_path)
    from src.cityscapes_mindrecord import create_icnet_dataset
    from src.models.icnet_dc import ICNetdc
    from src.lr_scheduler import poly_lr
    config_path = args_opt.project_path + "src/model_utils/icnet.yaml"
    with open(config_path, "r") as yaml_file:
        # safe_load needs no explicit Loader argument and only builds plain Python objects
        cfg = yaml.safe_load(yaml_file)
    logging.basicConfig(level=logging.INFO)
    train_net()
|