forked from OSSInnovation/mindspore
!14528 add ssd, resnet,unet evaluation while training process
From: @zhao_ting_v Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
c9c8d5fe44
|
@ -155,6 +155,7 @@ python eval.py --net=[resnet50|resnet101] --dataset=[cifar10|imagenet2012] --dat
|
|||
├── src
|
||||
├── config.py # parameter configuration
|
||||
├── dataset.py # data preprocessing
|
||||
├── eval_callback.py # evaluation callback while training
|
||||
├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset
|
||||
├── lr_generator.py # generate learning rate for each step
|
||||
├── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50
|
||||
|
@ -323,6 +324,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
|
|||
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
|
||||
```
|
||||
|
||||
#### Evaluation while training
|
||||
|
||||
To evaluate the model during training, add the `run_eval` option to the launch script and set it to True. When `run_eval` is True, you can also set the options `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
|
||||
|
||||
### Result
|
||||
|
||||
- Training ResNet18 with CIFAR-10 dataset
|
||||
|
|
|
@ -143,7 +143,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH]
|
|||
├── src
|
||||
├── config.py # 参数配置
|
||||
├── dataset.py # 数据预处理
|
||||
├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
|
||||
├── eval_callback.py # 训练时推理回调函数
|
||||
├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
|
||||
├── lr_generator.py # 生成每个步骤的学习率
|
||||
└── resnet.py # ResNet骨干网络,包括ResNet50、ResNet101和SE-ResNet50
|
||||
├── eval.py # 评估网络
|
||||
|
@ -297,6 +298,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
|
|||
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
```
|
||||
|
||||
#### 训练时推理
|
||||
|
||||
训练时推理需要在启动文件中添加`run_eval` 并设置为True。与此同时需要设置: `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch`, `eval_interval` 。
|
||||
|
||||
### 结果
|
||||
|
||||
- 使用CIFAR-10数据集训练ResNet18
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of selected epochs this callback calls ``eval_function(eval_param_dict)``, prints the
    result, remembers the best result seen so far and, if requested, saves the training network as the
    "best" checkpoint.

    Args:
        eval_function (function): evaluation function. It is called with `eval_param_dict` as its only
            argument and its return value is compared with ``>=`` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to, created if it does not
            exist, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        # Best result starts at 0, so only results >= 0 can ever become the best one.
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok=True: several training processes may construct this callback at the same time;
        # a separate isdir() check followed by makedirs() can raise FileExistsError in that race.
        os.makedirs(ckpt_directory, exist_ok=True)
        # NOTE: the spellings `besk_ckpt_name` / `bast_ckpt_path` are kept as-is because callers pass
        # the keyword by name.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Make `file_name` writable and delete it; failures are logged as warnings, not raised."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
            # `>=` means a tie with the current best still refreshes the checkpoint.
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    # Remove the previous best file first (it may be read-only) before saving anew.
                    if os.path.exists(self.bast_ckpt_path):
                        self.remove_ckpoint_file(self.bast_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluation metric."""
|
||||
|
||||
from mindspore.communication.management import GlobalComm
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class ClassifyCorrectCell(nn.Cell):
    r"""
    Cell that counts the correct predictions of a classification network.

    The wrapped network is applied to the data, its output is reduced with ``P.Argmax()``, compared
    with the labels, and the number of matches is summed locally and then all-reduced (SUM) over
    ``GlobalComm.WORLD_COMM_GROUP``.

    Args:
        network (Cell): The network Cell (without loss function).

    Inputs:
        - **data** (Tensor) - Input passed unchanged to `network`.
        - **label** (Tensor) - Labels compared element-wise with the int32 argmax of the network output.

    Outputs:
        Tuple with one element: the scalar count of correct predictions summed over the
        communication group.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> eval_net = ClassifyCorrectCell(net)
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self._network = network
        # Operators used by construct().
        self.argmax = P.Argmax()
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.reduce_sum = P.ReduceSum()
        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)

    def construct(self, data, label):
        logits = self._network(data)
        # Predicted class index, cast to int32 before it is compared with the labels.
        predicted = self.cast(self.argmax(logits), mstype.int32)
        # 1.0 where the prediction matches the label, 0.0 elsewhere.
        hits = self.cast(self.equal(predicted, label), mstype.float32)
        local_correct = self.reduce_sum(hits)
        # Sum the per-device counts across the whole communication group.
        return (self.allreduce(local_correct),)
|
||||
|
||||
|
||||
class DistAccuracy(nn.Metric):
    r"""
    Calculates the accuracy for classification data in distributed mode.

    The metric keeps two running values, the number of correct predictions and the total number of
    samples, and reports their ratio:

    .. math::

        \text{accuracy} = \frac{\text{correct number}}{\text{total number}}

    Every call to :meth:`update` adds ``batch_size * device_num`` to the total number.

    Args:
        batch_size (int): Number of samples counted per device for each update.
        device_num (int): Number of devices whose correct counts are contained in each update.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super().__init__()
        self.batch_size = batch_size
        self.device_num = device_num
        self.clear()

    def clear(self):
        """Resets the accumulated correct and total counts to zero."""
        self._total_num = 0
        self._correct_num = 0

    def update(self, *inputs):
        """
        Accumulates one evaluation step.

        Args:
            inputs: Exactly one input, `y_correct`: the count of right predictions gathered from
                all devices (documented upstream as a scalar float Tensor). It is converted with
                `self._convert_data` before being added.

        Raises:
            ValueError: If the number of the input is not 1.
        """
        num_inputs = len(inputs)
        if num_inputs != 1:
            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(num_inputs))
        self._correct_num += self._convert_data(inputs[0])
        # Each step is assumed to contribute a full batch from every device.
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Computes the accuracy.

        Returns:
            The accumulated correct count divided by the accumulated total count.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        accuracy = self._correct_num / self._total_num
        return accuracy
|
|
@ -31,9 +31,12 @@ from mindspore.common import set_seed
|
|||
from mindspore.parallel import set_algo_parameters
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
import mindspore.log as logger
|
||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
from src.config import cfg
|
||||
from src.eval_callback import EvalCallBack
|
||||
from src.metric import DistAccuracy, ClassifyCorrectCell
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--net', type=str, default=None, help='Resnet Model, resnet18, resnet50 or resnet101')
|
||||
|
@ -48,6 +51,15 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch
|
|||
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter head weight parameters, default is False.")
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Evaluation dataset path when run_eval is True')
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=40,
|
||||
help="Evaluation start epoch when run_eval is True, default is 40.")
|
||||
parser.add_argument("--eval_interval", type=int, default=1,
|
||||
help="Evaluation interval when run_eval is True, default is 1.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
set_seed(1)
|
||||
|
@ -89,6 +101,12 @@ def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
|
|||
del origin_dict[key]
|
||||
break
|
||||
|
||||
def apply_eval(eval_param):
    """Evaluate the model on the dataset and return the value of the requested metric.

    Args:
        eval_param (dict): must contain "model" (object with an ``eval(dataset)`` method returning a
            mapping of metric values), "dataset" and "metrics_name" (key to read from that mapping).

    Returns:
        The entry of the evaluation result stored under ``eval_param["metrics_name"]``.
    """
    # All three keys are read before evaluation starts, so a missing key fails early.
    network, data, key = (eval_param[name] for name in ("model", "dataset", "metrics_name"))
    return network.eval(data)[key]
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
|
@ -185,12 +203,16 @@ if __name__ == '__main__':
|
|||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False)
|
||||
dist_eval_network = ClassifyCorrectCell(net) if args_opt.run_distribute else None
|
||||
metrics = {"acc"}
|
||||
if args_opt.run_distribute:
|
||||
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)}
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,
|
||||
amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network)
|
||||
if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \
|
||||
args_opt.parameter_server or target == "CPU":
|
||||
## fp32 training
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network)
|
||||
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
|
||||
from src.lr_generator import get_thor_damping
|
||||
damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)
|
||||
|
@ -201,6 +223,8 @@ if __name__ == '__main__':
|
|||
loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False,
|
||||
frequency=config.frequency)
|
||||
args_opt.run_eval = False
|
||||
logger.warning("Thor optimizer not support evaluation while training.")
|
||||
|
||||
# define callbacks
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
|
@ -211,7 +235,17 @@ if __name__ == '__main__':
|
|||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
|
||||
# Optional evaluation while training: build the evaluation dataset and register the callback.
if args_opt.run_eval:
    # The evaluation dataset path must be given and must be an existing directory.
    if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
    eval_dataset = create_dataset(dataset_path=args_opt.eval_dataset_path, do_train=False,
                                  batch_size=config.batch_size, target=target)
    eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
    # Pass the parsed --save_best_ckpt value through; it was previously hard-coded to True,
    # which silently ignored the command line option.
    eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
                           eval_start_epoch=args_opt.eval_start_epoch,
                           save_best_ckpt=args_opt.save_best_ckpt,
                           ckpt_directory=ckpt_save_dir, besk_ckpt_name="best_acc.ckpt",
                           metrics_name="acc")
    cb += [eval_cb]
|
||||
# train model
|
||||
if args_opt.net == "se-resnet50":
|
||||
config.epoch_size = config.train_epoch_size
|
||||
|
|
|
@ -123,8 +123,8 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
|
|||
|
||||
### Prepare the model
|
||||
|
||||
1. Chose the model by changing the `using_model` in `src/confgi.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`.
|
||||
2. Change the dataset config in the corresponding config. `src/config_ssd300.py` or `src/config_ssd_mobilenet_v1_fpn.py`.
|
||||
1. Choose the model by changing the `using_model` in `src/config.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`, `ssd_resnet50_fpn`.
|
||||
2. Change the dataset config in the corresponding config. `src/config_ssd300.py`, `src/config_ssd_mobilenet_v1_fpn.py`, `src/config_ssd_resnet50_fpn.py`, `src/config_ssd_vgg16.py`.
|
||||
3. If you are running with `ssd_mobilenet_v1_fpn`, you need a pretrained model for `mobilenet_v1`. Set the checkpoint path to `feature_extractor_base_param` in `src/config_ssd_mobilenet_v1_fpn.py`. For more detail about training mobilenet_v1, please refer to the mobilenetv1 model.
|
||||
|
||||
### Run the scripts
|
||||
|
@ -201,6 +201,7 @@ Then you can run everything just like on ascend.
|
|||
├─ src
|
||||
├─ __init__.py # init file
|
||||
├─ box_utils.py # bbox utils
|
||||
├─ eval_callback.py # evaluation callback when training
|
||||
├─ eval_utils.py # metrics utils
|
||||
├─ config.py # total config
|
||||
├─ dataset.py # create dataset and process dataset
|
||||
|
@ -229,6 +230,10 @@ Then you can run everything just like on ascend.
|
|||
"loss_scale": 1024 # Loss scale
|
||||
"filter_weight": False # Load parameters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True.
|
||||
"freeze_layer": "none" # Freeze the backbone parameters or not, support none and backbone.
|
||||
"run_eval": False # Run evaluation when training
|
||||
"save_best_ckpt": True # Save best checkpoint when run_eval is True
|
||||
"eval_start_epoch": 40 # Evaluation start epoch when run_eval is True
|
||||
"eval_interval": 1 # valuation interval when run_eval is True
|
||||
|
||||
"class_num": 81 # Dataset class number
|
||||
"image_shape": [300, 300] # Image height and width used as input to the model
|
||||
|
@ -311,6 +316,10 @@ epoch time: 150753.701, per step time: 329.157
|
|||
...
|
||||
```
|
||||
|
||||
#### Evaluation while training
|
||||
|
||||
To evaluate the model during training, add the `run_eval` option to the launch script and set it to True. When `run_eval` is True, you can also set the options `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
|
||||
|
||||
#### Transfer Training
|
||||
|
||||
You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||
from src.dataset import create_ssd_dataset, create_mindrecord
|
||||
from src.config import config
|
||||
from src.eval_utils import metrics
|
||||
from src.eval_utils import apply_eval
|
||||
from src.box_utils import default_boxes
|
||||
|
||||
def ssd_eval(dataset_path, ckpt_path, anno_json):
|
||||
|
@ -50,31 +48,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
|
|||
load_param_into_net(net, param_dict)
|
||||
|
||||
net.set_train(False)
|
||||
i = batch_size
|
||||
total = ds.get_dataset_size() * batch_size
|
||||
start = time.time()
|
||||
pred_data = []
|
||||
print("\n========================================\n")
|
||||
print("total images num: ", total)
|
||||
print("Processing, please wait a moment.")
|
||||
for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
img_id = data['img_id']
|
||||
img_np = data['image']
|
||||
image_shape = data['image_shape']
|
||||
|
||||
output = net(Tensor(img_np))
|
||||
for batch_idx in range(img_np.shape[0]):
|
||||
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
|
||||
"box_scores": output[1].asnumpy()[batch_idx],
|
||||
"img_id": int(np.squeeze(img_id[batch_idx])),
|
||||
"image_shape": image_shape[batch_idx]})
|
||||
percent = round(i / total * 100., 2)
|
||||
|
||||
print(f' {str(percent)} [{i}/{total}]', end='\r')
|
||||
i += batch_size
|
||||
cost_time = int((time.time() - start) * 1000)
|
||||
print(f' 100% [{total}/{total}] cost {cost_time} ms')
|
||||
mAP = metrics(pred_data, anno_json)
|
||||
eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json}
|
||||
mAP = apply_eval(eval_param_dict)
|
||||
print("\n========================================\n")
|
||||
print(f"mAP: {mAP}")
|
||||
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of selected epochs this callback calls ``eval_function(eval_param_dict)``, prints the
    result, remembers the best result seen so far and, if requested, saves the training network as the
    "best" checkpoint.

    Args:
        eval_function (function): evaluation function. It is called with `eval_param_dict` as its only
            argument and its return value is compared with ``>=`` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to, created if it does not
            exist, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        # Best result starts at 0, so only results >= 0 can ever become the best one.
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok=True: several training processes may construct this callback at the same time;
        # a separate isdir() check followed by makedirs() can raise FileExistsError in that race.
        os.makedirs(ckpt_directory, exist_ok=True)
        # NOTE: the spellings `besk_ckpt_name` / `bast_ckpt_path` are kept as-is because callers pass
        # the keyword by name.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Make `file_name` writable and delete it; failures are logged as warnings, not raised."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
            # `>=` means a tie with the current best still refreshes the checkpoint.
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    # Remove the previous best file first (it may be read-only) before saving anew.
                    if os.path.exists(self.bast_ckpt_path):
                        self.remove_ckpoint_file(self.bast_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
|
@ -16,8 +16,28 @@
|
|||
|
||||
import json
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from .config import config
|
||||
|
||||
def apply_eval(eval_param_dict):
    """Run the inference network over the evaluation dataset and return the resulting metric.

    Args:
        eval_param_dict (dict): must contain "net" (inference network, called on a Tensor built from
            each image batch), "dataset" (dataset whose dict iterator yields 'img_id', 'image' and
            'image_shape') and "anno_json" (forwarded unchanged to `metrics`).

    Returns:
        The value returned by ``metrics(pred_data, anno_json)``.
    """
    net = eval_param_dict["net"]
    net.set_train(False)
    ds = eval_param_dict["dataset"]
    anno_json = eval_param_dict["anno_json"]
    pred_data = []
    for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
        img_id = data['img_id']
        img_np = data['image']
        image_shape = data['image_shape']

        output = net(Tensor(img_np))
        # Convert the two network outputs to numpy once per batch; previously `asnumpy()` was
        # re-evaluated for every image inside the loop below.
        boxes = output[0].asnumpy()
        box_scores = output[1].asnumpy()
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": boxes[batch_idx],
                              "box_scores": box_scores[batch_idx],
                              "img_id": int(np.squeeze(img_id[batch_idx])),
                              "image_shape": image_shape[batch_idx]})
    mAP = metrics(pred_data, anno_json)
    return mAP
|
||||
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
"""Apply NMS to bboxes."""
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Train SSD and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore.nn as nn
|
||||
|
@ -25,11 +26,15 @@ from mindspore.train import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed, dtype
|
||||
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||
from src.ssd import SSD300, SsdInferWithDecoder, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,\
|
||||
ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||
from src.config import config
|
||||
from src.dataset import create_ssd_dataset, create_mindrecord
|
||||
from src.lr_schedule import get_lr
|
||||
from src.init_params import init_net_param, filter_checkpoint_parameter_by_list
|
||||
from src.eval_callback import EvalCallBack
|
||||
from src.eval_utils import apply_eval
|
||||
from src.box_utils import default_boxes
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
@ -57,6 +62,14 @@ def get_args():
|
|||
parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
|
||||
help="freeze the weights of network, support freeze the backbone's weights, "
|
||||
"default is not freezing.")
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=40,
|
||||
help="Evaluation start epoch when run_eval is True, default is 40.")
|
||||
parser.add_argument("--eval_interval", type=int, default=1,
|
||||
help="Evaluation interval when run_eval is True, default is 1.")
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
|
@ -170,8 +183,25 @@ def main():
|
|||
config.momentum, config.weight_decay, loss_scale)
|
||||
net = TrainingWrapper(net, opt, loss_scale)
|
||||
|
||||
|
||||
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
|
||||
# Optional evaluation while training: build the inference network, dataset and callback.
if args_opt.run_eval:
    eval_net = SsdInferWithDecoder(ssd, Tensor(default_boxes), config)
    eval_net.set_train(False)
    mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)
    eval_dataset = create_ssd_dataset(mindrecord_file, batch_size=args_opt.batch_size, repeat_num=1,
                                      is_training=False, use_multiprocessing=False)
    # Pick the annotation file that matches the dataset type.
    if args_opt.dataset == "coco":
        anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
    elif args_opt.dataset == "voc":
        anno_json = os.path.join(config.voc_root, config.voc_json)
    else:
        raise ValueError('SSD eval only support dataset mode is coco and voc!')
    eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json}
    # Pass the parsed --save_best_ckpt value through; it was previously hard-coded to True,
    # which silently ignored the command line option.
    eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
                           eval_start_epoch=args_opt.eval_start_epoch,
                           save_best_ckpt=args_opt.save_best_ckpt,
                           ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt",
                           metrics_name="mAP")
    callback.append(eval_cb)
|
||||
model = Model(net)
|
||||
dataset_sink_mode = False
|
||||
if args_opt.mode == "sink" and args_opt.run_platform != "CPU":
|
||||
|
|
|
@ -128,6 +128,7 @@ Then you can run everything just like on ascend.
|
|||
│ ├──config.py // parameter configuration
|
||||
│ ├──data_loader.py // creating dataset
|
||||
│ ├──loss.py // loss
|
||||
│ ├──eval_callback.py // evaluation callback while training
|
||||
│ ├──utils.py // General components (callback function)
|
||||
│ ├──unet_medical // Unet medical architecture
|
||||
├──__init__.py // init file
|
||||
|
@ -168,6 +169,11 @@ Parameters for both training and evaluation can be set in config.py
|
|||
'resume_ckpt': './', # pretrain model path
|
||||
'transfer_training': False # whether do transfer training
|
||||
'filter_weight': ["final.weight"] # weight name to filter while doing transfer training
|
||||
'run_eval': False # Run evaluation when training
|
||||
'save_best_ckpt': True # Save best checkpoint when run_eval is True
|
||||
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
|
||||
'eval_interval': 1 # Evaluation interval when run_eval is True
|
||||
|
||||
```
|
||||
|
||||
- config for Unet++, cell nuclei dataset
|
||||
|
@ -193,6 +199,10 @@ Parameters for both training and evaluation can be set in config.py
|
|||
'resume_ckpt': './', # pretrain model path
|
||||
'transfer_training': False # whether do transfer training
|
||||
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight name to filter while doing transfer training
|
||||
'run_eval': False # Run evaluation when training
|
||||
'save_best_ckpt': True # Save best checkpoint when run_eval is True
|
||||
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
|
||||
'eval_interval': 1 # Evaluation interval when run_eval is True
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
@ -245,6 +255,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
|
|||
step: 300, loss is 0.18949677, fps is 57.63118508760329
|
||||
```
|
||||
|
||||
#### Evaluation while training
|
||||
|
||||
To run evaluation while training, add the `run_eval` argument to the launch script and set it to True. When `run_eval` is True, you can also set the optional arguments `save_best_ckpt`, `eval_start_epoch`, `eval_interval` and `eval_metrics`.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Evaluation
|
||||
|
|
|
@ -132,6 +132,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
|||
│ ├──config.py // 参数配置
|
||||
│ ├──data_loader.py // 数据处理
|
||||
│ ├──loss.py // 损失函数
|
||||
│ ├──eval_callback.py // 训练时推理回调函数
|
||||
│ ├──utils.py // 通用组件(回调函数)
|
||||
│ ├──unet_medical // 医学图像处理Unet结构
|
||||
├──__init__.py
|
||||
|
@ -247,6 +248,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
|
|||
step: 300, loss is 0.18949677, fps is 57.63118508760329
|
||||
```
|
||||
|
||||
#### 训练时推理
|
||||
|
||||
如需在训练时进行推理，请在启动脚本中添加`run_eval`并设置为True。此时还可以设置以下可选参数：`save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics`。
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
|
|
@ -16,10 +16,6 @@
|
|||
import os
|
||||
import argparse
|
||||
import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as F
|
||||
from mindspore import context, Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
@ -27,76 +23,11 @@ from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
|||
from src.unet_medical import UNetMedical
|
||||
from src.unet_nested import NestedUNet, UNet
|
||||
from src.config import cfg_unet
|
||||
from src.utils import UnetEval
|
||||
from src.utils import UnetEval, TempLoss, dice_coeff
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
|
||||
class TempLoss(nn.Cell):
    """Placeholder loss cell that hands the logits back unchanged."""

    def __init__(self):
        super().__init__()
        self.identity = F.identity()

    def construct(self, logits, label):
        # `label` is accepted but never used; only the logits are forwarded.
        passthrough = self.identity(logits)
        return passthrough
|
||||
|
||||
|
||||
class dice_coeff(nn.Metric):
    """Metric accumulating per-sample dice coefficient and IOU.

    The activation mode, resize flag and class count are read from the
    module-level ``cfg_unet`` dict.
    """

    def __init__(self):
        super().__init__()
        self.clear()

    def clear(self):
        """Reset the running sums and the sample counter."""
        self._dice_coeff_sum = 0
        self._iou_sum = 0
        self._samples_num = 0

    def update(self, *inputs):
        """Add one sample.

        Args:
            inputs: ``((y_softmax, y_argmax), y)``; exactly two items.

        Raises:
            ValueError: on a wrong number of inputs, a batch size other
                than 1, or an unsupported ``eval_activate`` setting.
        """
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))

        label = self._convert_data(inputs[1])
        # The counter is bumped before the batch-size check, as in the original flow.
        self._samples_num += label.shape[0]
        label = label.transpose(0, 2, 3, 1)
        batch, height, width, channels = label.shape
        if batch != 1:
            raise ValueError('Batch size should be 1 when in evaluation.')
        label = label.reshape((height, width, channels))

        activate = cfg_unet["eval_activate"].lower()
        if activate == "softmax":
            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
            if cfg_unet["eval_resize"]:
                # Resize each class plane via an 8-bit image, then rescale.
                planes = [
                    cv2.resize(np.uint8(y_softmax[:, :, cls] * 255), (width, height)) / 255
                    for cls in range(cfg_unet["num_classes"])
                ]
                prediction = np.stack(planes, axis=-1)
            else:
                prediction = y_softmax
        elif activate == "argmax":
            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
            # One binary mask per class; nearest-neighbour resize keeps it binary.
            planes = [
                cv2.resize(np.uint8(y_argmax == cls), (width, height), interpolation=cv2.INTER_NEAREST)
                if cfg_unet["eval_resize"] else np.float32(y_argmax == cls)
                for cls in range(cfg_unet["num_classes"])
            ]
            prediction = np.stack(planes, axis=-1)
        else:
            raise ValueError('config eval_activate should be softmax or argmax.')

        flat_pred = prediction.astype(np.float32).flatten()
        flat_label = label.flatten()
        inter = np.dot(flat_pred, flat_label)
        union = np.dot(flat_pred, flat_pred) + np.dot(flat_label, flat_label)

        # The small epsilon keeps the division defined when both masks are empty.
        single_dice_coeff = 2*float(inter)/float(union+1e-6)
        single_iou = single_dice_coeff / (2 - single_dice_coeff)
        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
        self._dice_coeff_sum += single_dice_coeff
        self._iou_sum += single_iou

    def eval(self):
        """Return ``(mean dice coefficient, mean IOU)`` over the samples seen.

        Raises:
            RuntimeError: if no sample has been added.
        """
        if self._samples_num == 0:
            raise RuntimeError('Total samples num must not be 0.')
        total = float(self._samples_num)
        return (self._dice_coeff_sum / total, self._iou_sum / total)
|
||||
|
||||
|
||||
def test_net(data_dir,
|
||||
ckpt_path,
|
||||
cross_valid_ind=1,
|
||||
|
@ -119,7 +50,7 @@ def test_net(data_dir,
|
|||
else:
|
||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
|
||||
do_crop=cfg['crop'], img_size=cfg['img_size'])
|
||||
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
|
||||
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)})
|
||||
|
||||
print("============== Starting Evaluating ============")
|
||||
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Runs ``eval_function`` at the end of selected epochs, remembers the best
    score seen so far and optionally saves the network that produced it.

    Args:
        eval_function (function): evaluation function, called as
            ``eval_function(eval_param_dict)``; its return value is compared
            with ``>=``, so a larger score is treated as better.
        eval_param_dict (dict): evaluation parameters' configure dict, passed
            to ``eval_function`` unchanged.
        interval (int): run evaluation interval in epochs, must be >= 1,
            default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to; it
            is created when missing, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
            The misspelled keyword is kept because callers pass it by name.
        metrics_name (str): evaluation metrics name used in the printed
            messages, default is `acc`.

    Raises:
        ValueError: if ``interval`` is smaller than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        # The best score starts at 0, so scores below 0 never count as an improvement.
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race when several processes
        # start with the same checkpoint directory.
        os.makedirs(ckpt_directory, exist_ok=True)
        # Attribute name (including its spelling) is kept for existing users.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the given checkpoint file; failures are logged as warnings, not raised."""
        try:
            # Make the file writable first so a read-only checkpoint can be deleted.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Evaluate at the end of a scheduled epoch and update the best result."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            # Not an evaluation epoch.
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # Ties also count as an update, so the latest equally good epoch wins.
        if res >= self.best_res:
            self.best_res = res
            self.best_epoch = cur_epoch
            print("update best result: {}".format(res), flush=True)
            if self.save_best_ckpt:
                if os.path.exists(self.bast_ckpt_path):
                    self.remove_ckpoint_file(self.bast_ckpt_path)
                save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Print the best score and the epoch it was reached at when training ends."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
|
@ -41,7 +41,7 @@ class MultiCrossEntropyWithLogits(nn.Cell):
|
|||
def __init__(self):
|
||||
super(MultiCrossEntropyWithLogits, self).__init__()
|
||||
self.loss = CrossEntropyWithLogits()
|
||||
self.squeeze = F.Squeeze()
|
||||
self.squeeze = F.Squeeze(axis=0)
|
||||
|
||||
def construct(self, logits, label):
|
||||
total_loss = 0
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from mindspore import nn
|
||||
|
@ -25,20 +26,100 @@ class UnetEval(nn.Cell):
|
|||
"""
|
||||
Add Unet evaluation activation.
|
||||
"""
|
||||
def __init__(self, net):
|
||||
def __init__(self, net, need_slice=False):
|
||||
super(UnetEval, self).__init__()
|
||||
self.net = net
|
||||
self.need_slice = need_slice
|
||||
self.transpose = ops.Transpose()
|
||||
self.softmax = ops.Softmax(axis=-1)
|
||||
self.argmax = ops.Argmax(axis=-1)
|
||||
self.squeeze = ops.Squeeze(axis=0)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.net(x)
|
||||
if self.need_slice:
|
||||
out = self.squeeze(out[-1:])
|
||||
out = self.transpose(out, (0, 2, 3, 1))
|
||||
softmax_out = self.softmax(out)
|
||||
argmax_out = self.argmax(out)
|
||||
return (softmax_out, argmax_out)
|
||||
|
||||
class TempLoss(nn.Cell):
    """Stand-in loss cell: returns the logits it is given, untouched."""

    def __init__(self):
        super().__init__()
        self.identity = ops.identity()

    def construct(self, logits, label):
        # `label` is part of the signature but is not used.
        forwarded = self.identity(logits)
        return forwarded
|
||||
|
||||
def apply_eval(eval_param_dict):
    """Run evaluation and return the score selected by ``metrics_name``.

    Args:
        eval_param_dict (dict): must contain
            ``"model"`` - object whose ``eval(dataset, dataset_sink_mode=False)``
            returns a dict with a ``"dice_coeff"`` entry holding the pair
            ``(dice coefficient, IOU)``;
            ``"dataset"`` - the evaluation dataset handed to ``model.eval``;
            ``"metrics_name"`` - ``"dice_coeff"`` or ``"iou"``.

    Returns:
        The dice coefficient or the IOU, depending on ``metrics_name``.

    Raises:
        ValueError: if ``metrics_name`` is neither ``"dice_coeff"`` nor ``"iou"``.
    """
    model = eval_param_dict["model"]
    dataset = eval_param_dict["dataset"]
    metrics_name = eval_param_dict["metrics_name"]
    # Position of each supported score inside the (dice, iou) pair.
    score_index = {"dice_coeff": 0, "iou": 1}
    if metrics_name not in score_index:
        raise ValueError("metrics_name should be dice_coeff or iou, but got {}".format(metrics_name))
    # Both scores live under the single "dice_coeff" metric key; looking the
    # result up by metrics_name raised KeyError whenever "iou" was requested.
    scores = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"]
    return scores[score_index[metrics_name]]
|
||||
|
||||
class dice_coeff(nn.Metric):
    """Unet metric that reports the mean dice coefficient and mean IOU.

    Args:
        cfg_unet (dict): configuration; the keys ``eval_activate``,
            ``eval_resize`` and ``num_classes`` are read during ``update``.
        print_res (bool): whether to print the per-sample scores. Default: True.
    """

    def __init__(self, cfg_unet, print_res=True):
        super().__init__()
        self.clear()
        self.cfg_unet = cfg_unet
        self.print_res = print_res

    def clear(self):
        """Reset the accumulated sums and the sample counter."""
        self._dice_coeff_sum = 0
        self._iou_sum = 0
        self._samples_num = 0

    def _build_prediction(self, outputs, height, width):
        """Turn the network outputs into an (h, w, num_classes) prediction array."""
        activate = self.cfg_unet["eval_activate"].lower()
        if activate == "softmax":
            y_softmax = np.squeeze(self._convert_data(outputs[0]), axis=0)
            if not self.cfg_unet["eval_resize"]:
                return y_softmax
            # Resize each class plane via an 8-bit image, then rescale.
            planes = [
                cv2.resize(np.uint8(y_softmax[:, :, cls] * 255), (width, height)) / 255
                for cls in range(self.cfg_unet["num_classes"])
            ]
            return np.stack(planes, axis=-1)
        if activate == "argmax":
            y_argmax = np.squeeze(self._convert_data(outputs[1]), axis=0)
            # One binary mask per class; nearest-neighbour resize keeps it binary.
            planes = [
                cv2.resize(np.uint8(y_argmax == cls), (width, height), interpolation=cv2.INTER_NEAREST)
                if self.cfg_unet["eval_resize"] else np.float32(y_argmax == cls)
                for cls in range(self.cfg_unet["num_classes"])
            ]
            return np.stack(planes, axis=-1)
        raise ValueError('config eval_activate should be softmax or argmax.')

    def update(self, *inputs):
        """Add one sample.

        Args:
            inputs: ``((y_softmax, y_argmax), y)``; exactly two items.

        Raises:
            ValueError: on a wrong number of inputs, a batch size other
                than 1, or an unsupported ``eval_activate`` setting.
        """
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))

        label = self._convert_data(inputs[1])
        # The counter is bumped before the batch-size check, as in the original flow.
        self._samples_num += label.shape[0]
        label = label.transpose(0, 2, 3, 1)
        batch, height, width, channels = label.shape
        if batch != 1:
            raise ValueError('Batch size should be 1 when in evaluation.')
        label = label.reshape((height, width, channels))

        prediction = self._build_prediction(inputs[0], height, width)
        flat_pred = prediction.astype(np.float32).flatten()
        flat_label = label.flatten()
        inter = np.dot(flat_pred, flat_label)
        union = np.dot(flat_pred, flat_pred) + np.dot(flat_label, flat_label)

        # The small epsilon keeps the division defined when both masks are empty.
        single_dice_coeff = 2 * float(inter) / float(union+1e-6)
        single_iou = single_dice_coeff / (2 - single_dice_coeff)
        if self.print_res:
            print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
        self._dice_coeff_sum += single_dice_coeff
        self._iou_sum += single_iou

    def eval(self):
        """Return ``(mean dice coefficient, mean IOU)`` over the samples seen.

        Raises:
            RuntimeError: if no sample has been added.
        """
        if self._samples_num == 0:
            raise RuntimeError('Total samples num must not be 0.')
        total = float(self._samples_num)
        return (self._dice_coeff_sum / total, self._iou_sum / total)
|
||||
|
||||
class StepLossTimeMonitor(Callback):
|
||||
|
||||
def __init__(self, batch_size, per_print_times=1):
|
||||
|
|
|
@ -30,23 +30,25 @@ from src.unet_medical import UNetMedical
|
|||
from src.unet_nested import NestedUNet, UNet
|
||||
from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
||||
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
|
||||
from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list
|
||||
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
|
||||
from src.config import cfg_unet
|
||||
from src.eval_callback import EvalCallBack
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
mindspore.set_seed(1)
|
||||
|
||||
def train_net(data_dir,
|
||||
def train_net(args_opt,
|
||||
cross_valid_ind=1,
|
||||
epochs=400,
|
||||
batch_size=16,
|
||||
lr=0.0001,
|
||||
run_distribute=False,
|
||||
cfg=None):
|
||||
rank = 0
|
||||
group_size = 1
|
||||
data_dir = args_opt.data_url
|
||||
run_distribute = args_opt.run_distribute
|
||||
if run_distribute:
|
||||
init()
|
||||
group_size = get_group_size()
|
||||
|
@ -55,12 +57,13 @@ def train_net(data_dir,
|
|||
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
||||
device_num=group_size,
|
||||
gradients_mean=False)
|
||||
|
||||
need_slice = False
|
||||
if cfg['model'] == 'unet_medical':
|
||||
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
elif cfg['model'] == 'unet_nested':
|
||||
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
|
||||
use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
|
||||
need_slice = cfg['use_ds']
|
||||
elif cfg['model'] == 'unet_simple':
|
||||
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
|
||||
else:
|
||||
|
@ -83,12 +86,15 @@ def train_net(data_dir,
|
|||
train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
|
||||
is_train=True, augment=True, split=0.8, rank=rank,
|
||||
group_size=group_size)
|
||||
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
|
||||
eval_resize=cfg["eval_resize"], split=0.8,
|
||||
python_multiprocessing=False)
|
||||
else:
|
||||
repeat = epochs
|
||||
dataset_sink_mode = False
|
||||
per_print_times = 1
|
||||
train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
|
||||
cfg["crop"], cfg['img_size'])
|
||||
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
|
||||
run_distribute, cfg["crop"], cfg['img_size'])
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
print("dataset length is:", train_data_size)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
|
@ -106,6 +112,15 @@ def train_net(data_dir,
|
|||
|
||||
print("============== Starting Training ==============")
|
||||
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
|
||||
if args_opt.run_eval:
|
||||
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
|
||||
metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
|
||||
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
|
||||
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
|
||||
ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
|
||||
metrics_name=args_opt.eval_metrics)
|
||||
callbacks.append(eval_cb)
|
||||
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
|
||||
print("============== End Training ==============")
|
||||
|
||||
|
@ -117,6 +132,17 @@ def get_args():
|
|||
help='data directory')
|
||||
parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
|
||||
default=False, help='Run distribute, default: false.')
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=0,
|
||||
help="Evaluation start epoch when run_eval is True, default is 0.")
|
||||
parser.add_argument("--eval_interval", type=int, default=1,
|
||||
help="Evaluation interval when run_eval is True, default is 1.")
|
||||
parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
|
||||
help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
|
||||
"default is dice_coeff.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -127,10 +153,9 @@ if __name__ == '__main__':
|
|||
print("Training setting:", args)
|
||||
|
||||
epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
|
||||
train_net(data_dir=args.data_url,
|
||||
train_net(args_opt=args,
|
||||
cross_valid_ind=cfg_unet['cross_valid_ind'],
|
||||
epochs=epoch_size,
|
||||
batch_size=cfg_unet['batchsize'],
|
||||
lr=cfg_unet['lr'],
|
||||
run_distribute=args.run_distribute,
|
||||
cfg=cfg_unet)
|
||||
|
|
Loading…
Reference in New Issue