!14528 add SSD, ResNet and UNet evaluation during the training process

From: @zhao_ting_v
Reviewed-by: @c_34, @wuxuejian
Signed-off-by: @c_34
mindspore-ci-bot 2021-04-06 10:27:53 +08:00 committed by Gitee
commit c9c8d5fe44
17 changed files with 654 additions and 114 deletions

View File

@ -155,6 +155,7 @@ python eval.py --net=[resnet50|resnet101] --dataset=[cifar10|imagenet2012] --dat
├── src
├── config.py # parameter configuration
├── dataset.py # data preprocessing
├── eval_callback.py # evaluation callback while training
├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset
├── lr_generator.py # generate learning rate for each step
├── resnet.py # resnet backbone, including resnet50, resnet101 and se-resnet50
@ -323,6 +324,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
#### Evaluation while training
If you want to evaluate while training, add `run_eval` to the launch shell script and set it to True. When `run_eval` is True, you can also set the options `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch`, and `eval_interval`.
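A standalone launch might then look like the following sketch; the dataset paths are illustrative, and the `--dataset_path` option is assumed to match the existing training script:

```bash
python train.py --net=resnet50 --dataset=imagenet2012 --dataset_path=/path/to/train \
    --run_eval=True --eval_dataset_path=/path/to/val \
    --save_best_ckpt=True --eval_start_epoch=40 --eval_interval=1
```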
### Result
- Training ResNet18 with CIFAR-10 dataset

View File

@ -143,7 +143,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH]
├── src
├── config.py # parameter configuration
├── dataset.py # data preprocessing
├── CrossEntropySmooth.py # loss definition for the ImageNet2012 dataset
├── eval_callback.py # evaluation callback while training
├── CrossEntropySmooth.py # loss definition for the ImageNet2012 dataset
├── lr_generator.py # generate the learning rate for each step
└── resnet.py # ResNet backbone, including ResNet50, ResNet101 and SE-ResNet50
├── eval.py # evaluate network
@ -297,6 +298,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
#### Evaluation while training
To evaluate while training, add `run_eval` to the launch script and set it to True. You also need to set `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch`, and `eval_interval` when `run_eval` is True.
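For example, a sketch of a standalone launch (paths are illustrative, and the `--dataset_path` option is assumed to match the existing training script):

```bash
python train.py --net=resnet50 --dataset=imagenet2012 --dataset_path=/path/to/train \
    --run_eval=True --eval_dataset_path=/path/to/val --eval_start_epoch=40
```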
### Result
- Training ResNet18 with the CIFAR-10 dataset

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.best_ckpt_path):
self.remove_ckpoint_file(self.best_ckpt_path)
save_checkpoint(cb_params.train_network, self.best_ckpt_path)
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation metric."""
from mindspore.communication.management import GlobalComm
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
class ClassifyCorrectCell(nn.Cell):
r"""
Cell that returns the correct count of predictions in a classification network.
This Cell accepts a network as its argument.
It returns the correct count of predictions to calculate the metrics.
Args:
network (Cell): The network Cell.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tuple, containing a scalar correct count of the prediction.
Examples:
>>> # For a defined network Net without loss function
>>> net = Net()
>>> eval_net = ClassifyCorrectCell(net)
"""
def __init__(self, network):
super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
self._network = network
self.argmax = P.Argmax()
self.equal = P.Equal()
self.cast = P.Cast()
self.reduce_sum = P.ReduceSum()
self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
def construct(self, data, label):
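# Compare the argmax of the logits with the labels, sum the correct count
# on this device, then AllReduce the counts so every rank gets the global total.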
outputs = self._network(data)
y_pred = self.argmax(outputs)
y_pred = self.cast(y_pred, mstype.int32)
y_correct = self.equal(y_pred, label)
y_correct = self.cast(y_correct, mstype.float32)
y_correct = self.reduce_sum(y_correct)
total_correct = self.allreduce(y_correct)
return (total_correct,)
class DistAccuracy(nn.Metric):
r"""
Calculates the accuracy for classification data in distributed mode.
The accuracy class creates two local variables, the correct number and the total number, that are used
to compute the frequency with which predictions match labels. This frequency is ultimately returned as
the accuracy: an idempotent operation that simply divides the correct number by the total number.
.. math::
\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
{\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
Args:
batch_size (int): Batch size of each device.
device_num (int): Number of devices.
Examples:
>>> y_correct = Tensor(np.array([20]))
>>> metric = DistAccuracy(batch_size=3, device_num=8)
>>> metric.clear()
>>> metric.update(y_correct)
>>> accuracy = metric.eval()
"""
def __init__(self, batch_size, device_num):
super(DistAccuracy, self).__init__()
self.clear()
self.batch_size = batch_size
self.device_num = device_num
def clear(self):
"""Clears the internal evaluation result."""
self._correct_num = 0
self._total_num = 0
def update(self, *inputs):
"""
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
Args:
inputs: Input `y_correct`. `y_correct` is a scalar Tensor.
`y_correct` is the correct prediction count gathered from all devices;
it is a scalar float value.
Raises:
ValueError: If the number of the input is not 1.
"""
if len(inputs) != 1:
raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
y_correct = self._convert_data(inputs[0])
self._correct_num += y_correct
self._total_num += self.batch_size * self.device_num
def eval(self):
"""
Computes the accuracy.
Returns:
Float, the computed result.
Raises:
RuntimeError: If the sample size is 0.
"""
if self._total_num == 0:
raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num

View File

@ -31,9 +31,12 @@ from mindspore.common import set_seed
from mindspore.parallel import set_algo_parameters
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
import mindspore.log as logger
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.config import cfg
from src.eval_callback import EvalCallBack
from src.metric import DistAccuracy, ClassifyCorrectCell
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, resnet18, resnet50 or resnet101')
@ -48,6 +51,15 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter head weight parameters, default is False.")
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Evaluation dataset path when run_eval is True')
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=40,
help="Evaluation start epoch when run_eval is True, default is 40.")
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, default is 1.")
args_opt = parser.parse_args()
set_seed(1)
@ -89,6 +101,12 @@ def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
del origin_dict[key]
break
def apply_eval(eval_param):
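"""Run Model.eval on the evaluation dataset and return the metric named by metrics_name."""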
eval_model = eval_param["model"]
eval_ds = eval_param["dataset"]
metrics_name = eval_param["metrics_name"]
res = eval_model.eval(eval_ds)
return res[metrics_name]
if __name__ == '__main__':
target = args_opt.device_target
@ -185,12 +203,16 @@ if __name__ == '__main__':
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
dist_eval_network = ClassifyCorrectCell(net) if args_opt.run_distribute else None
metrics = {"acc"}
if args_opt.run_distribute:
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)}
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,
amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network)
if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \
args_opt.parameter_server or target == "CPU":
# fp32 training
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network)
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
from src.lr_generator import get_thor_damping
damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)
@ -201,6 +223,8 @@ if __name__ == '__main__':
loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False,
frequency=config.frequency)
args_opt.run_eval = False
logger.warning("Thor optimizer not support evaluation while training.")
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
@ -211,7 +235,17 @@ if __name__ == '__main__':
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
if args_opt.run_eval:
if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
eval_dataset = create_dataset(dataset_path=args_opt.eval_dataset_path, do_train=False,
batch_size=config.batch_size, target=target)
eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=ckpt_save_dir, best_ckpt_name="best_acc.ckpt",
metrics_name="acc")
cb += [eval_cb]
# train model
if args_opt.net == "se-resnet50":
config.epoch_size = config.train_epoch_size

View File

@ -123,8 +123,8 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
### Prepare the model
1. Choose the model by changing `using_model` in `src/config.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`.
2. Change the dataset config in the corresponding config. `src/config_ssd300.py` or `src/config_ssd_mobilenet_v1_fpn.py`.
1. Choose the model by changing `using_model` in `src/config.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`, `ssd_resnet50_fpn`, `ssd_vgg16`.
2. Change the dataset config in the corresponding config file: `src/config_ssd300.py`, `src/config_ssd_mobilenet_v1_fpn.py`, `src/config_ssd_resnet50_fpn.py`, or `src/config_ssd_vgg16.py`.
3. If you are running with `ssd_mobilenet_v1_fpn`, you need a pretrained model for `mobilenet_v1`. Set the checkpoint path to `feature_extractor_base_param` in `src/config_ssd_mobilenet_v1_fpn.py`. For more details about training mobilenet_v1, please refer to the mobilenetv1 model.
### Run the scripts
@ -201,6 +201,7 @@ Then you can run everything just like on ascend.
├─ src
├─ __init__.py # init file
├─ box_utils.py # bbox utils
├─ eval_callback.py # evaluation callback when training
├─ eval_utils.py # metrics utils
├─ config.py # total config
├─ dataset.py # create dataset and process dataset
@ -229,6 +230,10 @@ Then you can run everything just like on ascend.
"loss_scale": 1024 # Loss scale
"filter_weight": False # Load parameters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True.
"freeze_layer": "none" # Freeze the backbone parameters or not, support none and backbone.
"run_eval": False # Run evaluation when training
"save_best_ckpt": True # Save best checkpoint when run_eval is True
"eval_start_epoch": 40 # Evaluation start epoch when run_eval is True
"eval_interval": 1 # valuation interval when run_eval is True
"class_num": 81 # Dataset class number
"image_shape": [300, 300] # Image height and width used as input to the model
@ -311,6 +316,10 @@ epoch time: 150753.701, per step time: 329.157
...
```
#### Evaluation while training
If you want to evaluate while training, add `run_eval` to the launch shell script and set it to True. When `run_eval` is True, you can also set the options `save_best_ckpt`, `eval_start_epoch`, and `eval_interval`.
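For example, a minimal launch sketch (the evaluation flags match the `train.py` arguments added in this commit; any other `train.py` options keep their usual values):

```bash
python train.py --dataset=coco --run_eval=True --save_best_ckpt=True \
    --eval_start_epoch=40 --eval_interval=1
```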
#### Transfer Training
You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.

View File

@ -17,14 +17,12 @@
import os
import argparse
import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
from src.eval_utils import metrics
from src.eval_utils import apply_eval
from src.box_utils import default_boxes
def ssd_eval(dataset_path, ckpt_path, anno_json):
@ -50,31 +48,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
load_param_into_net(net, param_dict)
net.set_train(False)
i = batch_size
total = ds.get_dataset_size() * batch_size
start = time.time()
pred_data = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
img_id = data['img_id']
img_np = data['image']
image_shape = data['image_shape']
output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"img_id": int(np.squeeze(img_id[batch_idx])),
"image_shape": image_shape[batch_idx]})
percent = round(i / total * 100., 2)
print(f' {str(percent)} [{i}/{total}]', end='\r')
i += batch_size
cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data, anno_json)
eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json}
mAP = apply_eval(eval_param_dict)
print("\n========================================\n")
print(f"mAP: {mAP}")

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.best_ckpt_path):
self.remove_ckpoint_file(self.best_ckpt_path)
save_checkpoint(cb_params.train_network, self.best_ckpt_path)
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

View File

@ -16,8 +16,28 @@
import json
import numpy as np
from mindspore import Tensor
from .config import config
def apply_eval(eval_param_dict):
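"""Run inference over the evaluation dataset, collect the predictions, and return the COCO mAP computed by metrics()."""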
net = eval_param_dict["net"]
net.set_train(False)
ds = eval_param_dict["dataset"]
anno_json = eval_param_dict["anno_json"]
pred_data = []
for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
img_id = data['img_id']
img_np = data['image']
image_shape = data['image_shape']
output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"img_id": int(np.squeeze(img_id[batch_idx])),
"image_shape": image_shape[batch_idx]})
mAP = metrics(pred_data, anno_json)
return mAP
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""

View File

@ -15,6 +15,7 @@
"""Train SSD and get checkpoint files."""
import os
import argparse
import ast
import mindspore.nn as nn
@ -25,11 +26,15 @@ from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, dtype
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.ssd import SSD300, SsdInferWithDecoder, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,\
ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter_by_list
from src.eval_callback import EvalCallBack
from src.eval_utils import apply_eval
from src.box_utils import default_boxes
set_seed(1)
@ -57,6 +62,14 @@ def get_args():
parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
help="freeze the weights of network, support freeze the backbone's weights, "
"default is not freezing.")
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=40,
help="Evaluation start epoch when run_eval is True, default is 40.")
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, default is 1.")
args_opt = parser.parse_args()
return args_opt
@ -170,8 +183,25 @@ def main():
config.momentum, config.weight_decay, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
if args_opt.run_eval:
eval_net = SsdInferWithDecoder(ssd, Tensor(default_boxes), config)
eval_net.set_train(False)
mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)
eval_dataset = create_ssd_dataset(mindrecord_file, batch_size=args_opt.batch_size, repeat_num=1,
is_training=False, use_multiprocessing=False)
if args_opt.dataset == "coco":
anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
elif args_opt.dataset == "voc":
anno_json = os.path.join(config.voc_root, config.voc_json)
else:
raise ValueError('SSD evaluation only supports coco and voc datasets!')
eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=save_ckpt_path, best_ckpt_name="best_map.ckpt",
metrics_name="mAP")
callback.append(eval_cb)
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "sink" and args_opt.run_platform != "CPU":

View File

@ -128,6 +128,7 @@ Then you can run everything just like on ascend.
│ ├──config.py // parameter configuration
│ ├──data_loader.py // creating dataset
│ ├──loss.py // loss
│ ├──eval_callback.py // evaluation callback while training
│ ├──utils.py // General components (callback function)
│ ├──unet_medical // Unet medical architecture
├──__init__.py // init file
@ -168,6 +169,11 @@ Parameters for both training and evaluation can be set in config.py
'resume_ckpt': './', # pretrain model path
'transfer_training': False # whether to do transfer training
'filter_weight': ["final.weight"] # weight name to filter while doing transfer training
'run_eval': False # Run evaluation when training
'save_best_ckpt': True # Save best checkpoint when run_eval is True
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
'eval_interval': 1 # Evaluation interval when run_eval is True
```
- config for Unet++, cell nuclei dataset
@ -193,6 +199,10 @@ Parameters for both training and evaluation can be set in config.py
'resume_ckpt': './', # pretrain model path
'transfer_training': False # whether to do transfer training
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight name to filter while doing transfer training
'run_eval': False # Run evaluation when training
'save_best_ckpt': True # Save best checkpoint when run_eval is True
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
'eval_interval': 1 # Evaluation interval when run_eval is True
```
## [Training Process](#contents)
@ -245,6 +255,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
```
#### Evaluation while training
If you want to evaluate while training, add `run_eval` to the launch shell script and set it to True. When `run_eval` is True, you can also set the options `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, and `eval_metrics`.
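For example, a minimal launch sketch (the data path is illustrative; the evaluation flags match the `train.py` arguments added in this commit):

```bash
python train.py --data_url=/path/to/data/ --run_eval=True \
    --eval_start_epoch=0 --eval_interval=1 --eval_metrics=dice_coeff
```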
## [Evaluation Process](#contents)
### Evaluation

View File

@ -132,6 +132,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
│ ├──config.py // parameter configuration
│ ├──data_loader.py // data processing
│ ├──loss.py // loss function
│ ├──eval_callback.py // evaluation callback while training
│ ├──utils.py // general components (callback functions)
│ ├──unet_medical // UNet architecture for medical images
├──__init__.py
@ -247,6 +248,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
```
#### Evaluation while training
To evaluate while training, add `run_eval` to the launch script and set it to True. You also need to set `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, and `eval_metrics`.
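For example (the data path is illustrative):

```bash
python train.py --data_url=/path/to/data/ --run_eval=True --eval_metrics=dice_coeff
```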
## Evaluation Process
### Evaluation

View File

@ -16,10 +16,6 @@
import os
import argparse
import logging
import cv2
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -27,76 +23,11 @@ from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
from src.utils import UnetEval
from src.utils import UnetEval, TempLoss, dice_coeff
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
class TempLoss(nn.Cell):
"""A temp loss cell."""
def __init__(self):
super(TempLoss, self).__init__()
self.identity = F.identity()
def construct(self, logits, label):
return self.identity(logits)
class dice_coeff(nn.Metric):
def __init__(self):
super(dice_coeff, self).__init__()
self.clear()
def clear(self):
self._dice_coeff_sum = 0
self._iou_sum = 0
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y = y.transpose(0, 2, 3, 1)
b, h, w, c = y.shape
if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c))
if cfg_unet["eval_activate"].lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
if cfg_unet["eval_resize"]:
y_pred = []
for i in range(cfg_unet["num_classes"]):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1)
else:
y_pred = y_softmax
elif cfg_unet["eval_activate"].lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_pred = []
for i in range(cfg_unet["num_classes"]):
if cfg_unet["eval_resize"]:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else:
y_pred.append(np.float32(y_argmax == i))
y_pred = np.stack(y_pred, axis=-1)
else:
raise ValueError('config eval_activate should be softmax or argmax.')
y_pred = y_pred.astype(np.float32)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2*float(inter)/float(union+1e-6)
single_iou = single_dice_coeff / (2 - single_dice_coeff)
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
self._dice_coeff_sum += single_dice_coeff
self._iou_sum += single_iou
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
def test_net(data_dir,
ckpt_path,
cross_valid_ind=1,
@ -119,7 +50,7 @@ def test_net(data_dir,
else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size'])
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)})
print("============== Starting Evaluating ============")
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.best_ckpt_path):
self.remove_ckpoint_file(self.best_ckpt_path)
save_checkpoint(cb_params.train_network, self.best_ckpt_path)
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

View File

@ -41,7 +41,7 @@ class MultiCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(MultiCrossEntropyWithLogits, self).__init__()
self.loss = CrossEntropyWithLogits()
self.squeeze = F.Squeeze()
self.squeeze = F.Squeeze(axis=0)
def construct(self, logits, label):
total_loss = 0

View File

@ -14,6 +14,7 @@
# ============================================================================
import time
import cv2
import numpy as np
from PIL import Image
from mindspore import nn
@ -25,20 +26,100 @@ class UnetEval(nn.Cell):
"""
Add Unet evaluation activation.
"""
def __init__(self, net):
def __init__(self, net, need_slice=False):
super(UnetEval, self).__init__()
self.net = net
self.need_slice = need_slice
self.transpose = ops.Transpose()
self.softmax = ops.Softmax(axis=-1)
self.argmax = ops.Argmax(axis=-1)
self.squeeze = ops.Squeeze(axis=0)
def construct(self, x):
out = self.net(x)
if self.need_slice:
out = self.squeeze(out[-1:])
out = self.transpose(out, (0, 2, 3, 1))
softmax_out = self.softmax(out)
argmax_out = self.argmax(out)
return (softmax_out, argmax_out)
class TempLoss(nn.Cell):
"""A temp loss cell."""
def __init__(self):
super(TempLoss, self).__init__()
self.identity = ops.identity()
def construct(self, logits, label):
return self.identity(logits)
def apply_eval(eval_param_dict):
"""run Evaluation"""
model = eval_param_dict["model"]
dataset = eval_param_dict["dataset"]
metrics_name = eval_param_dict["metrics_name"]
index = 0 if metrics_name == "dice_coeff" else 1
eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
return eval_score
class dice_coeff(nn.Metric):
"""Unet Metric, return dice coefficient and IOU."""
def __init__(self, cfg_unet, print_res=True):
super(dice_coeff, self).__init__()
self.clear()
self.cfg_unet = cfg_unet
self.print_res = print_res
def clear(self):
self._dice_coeff_sum = 0
self._iou_sum = 0
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y = y.transpose(0, 2, 3, 1)
b, h, w, c = y.shape
if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c))
if self.cfg_unet["eval_activate"].lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
if self.cfg_unet["eval_resize"]:
y_pred = []
for i in range(self.cfg_unet["num_classes"]):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1)
else:
y_pred = y_softmax
elif self.cfg_unet["eval_activate"].lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_pred = []
for i in range(self.cfg_unet["num_classes"]):
if self.cfg_unet["eval_resize"]:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else:
y_pred.append(np.float32(y_argmax == i))
y_pred = np.stack(y_pred, axis=-1)
else:
raise ValueError('config eval_activate should be softmax or argmax.')
y_pred = y_pred.astype(np.float32)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2 * float(inter) / float(union+1e-6)
single_iou = single_dice_coeff / (2 - single_dice_coeff)
if self.print_res:
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
self._dice_coeff_sum += single_dice_coeff
self._iou_sum += single_iou
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
class StepLossTimeMonitor(Callback):
def __init__(self, batch_size, per_print_times=1):

View File

@ -30,23 +30,25 @@ from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
from src.config import cfg_unet
from src.eval_callback import EvalCallBack
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
mindspore.set_seed(1)
def train_net(data_dir,
def train_net(args_opt,
cross_valid_ind=1,
epochs=400,
batch_size=16,
lr=0.0001,
run_distribute=False,
cfg=None):
rank = 0
group_size = 1
data_dir = args_opt.data_url
run_distribute = args_opt.run_distribute
if run_distribute:
init()
group_size = get_group_size()
@ -55,12 +57,13 @@ def train_net(data_dir,
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
gradients_mean=False)
need_slice = False
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
need_slice = cfg['use_ds']
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
@ -83,12 +86,15 @@ def train_net(data_dir,
train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
is_train=True, augment=True, split=0.8, rank=rank,
group_size=group_size)
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
eval_resize=cfg["eval_resize"], split=0.8,
python_multiprocessing=False)
else:
repeat = epochs
dataset_sink_mode = False
per_print_times = 1
train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
cfg["crop"], cfg['img_size'])
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
run_distribute, cfg["crop"], cfg['img_size'])
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
@ -106,6 +112,15 @@ def train_net(data_dir,
print("============== Starting Training ==============")
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
if args_opt.run_eval:
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
ckpt_directory='./ckpt_{}/'.format(device_id), best_ckpt_name="best.ckpt",
metrics_name=args_opt.eval_metrics)
callbacks.append(eval_cb)
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
print("============== End Training ==============")
@ -117,6 +132,17 @@ def get_args():
help='data directory')
parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
default=False, help='Run distribute, default: false.')
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=0,
help="Evaluation start epoch when run_eval is True, default is 0.")
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, default is 1.")
parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
"default is dice_coeff.")
return parser.parse_args()
@ -127,10 +153,9 @@ if __name__ == '__main__':
print("Training setting:", args)
epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
train_net(data_dir=args.data_url,
train_net(args_opt=args,
cross_valid_ind=cfg_unet['cross_valid_ind'],
epochs=epoch_size,
batch_size=cfg_unet['batchsize'],
lr=cfg_unet['lr'],
run_distribute=args.run_distribute,
cfg=cfg_unet)