diff --git a/model_zoo/official/cv/resnet/README.md b/model_zoo/official/cv/resnet/README.md
index a677b9e5b5..716dc3f46b 100644
--- a/model_zoo/official/cv/resnet/README.md
+++ b/model_zoo/official/cv/resnet/README.md
@@ -155,6 +155,7 @@ python eval.py --net=[resnet50|resnet101] --dataset=[cifar10|imagenet2012] --dat
   ├── src
     ├── config.py # parameter configuration
     ├── dataset.py # data preprocessing
+    ├── eval_callback.py # evaluation callback while training
    ├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset
     ├── lr_generator.py # generate learning rate for each step
     ├── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50
@@ -323,6 +324,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
 bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
 ```
 
+#### Evaluation while training
+
+To evaluate while training, add `run_eval=True` to the start shell. When `run_eval` is True, you can also set the options `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
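+
+For example, a typical invocation might look like this (the dataset paths are placeholders; adjust them to your environment):
+
+```bash
+python train.py --net=resnet50 --dataset=cifar10 --dataset_path=/path/to/cifar10/train \
+    --run_eval=True --eval_dataset_path=/path/to/cifar10/eval \
+    --save_best_ckpt=True --eval_start_epoch=40 --eval_interval=1
+```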
+
 ### Result
 
 - Training ResNet18 with CIFAR-10 dataset

diff --git a/model_zoo/official/cv/resnet/README_CN.md b/model_zoo/official/cv/resnet/README_CN.md
index 5d5cb5d389..fe06881677 100755
--- a/model_zoo/official/cv/resnet/README_CN.md
+++ b/model_zoo/official/cv/resnet/README_CN.md
@@ -143,7 +143,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH]
   ├── src
     ├── config.py # 参数配置
     ├── dataset.py # 数据预处理
-    ├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
+    ├── eval_callback.py # 训练时推理回调函数
+    ├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
     ├── lr_generator.py # 生成每个步骤的学习率
     └── resnet.py # ResNet骨干网络,包括ResNet50、ResNet101和SE-ResNet50
   ├── eval.py # 评估网络
@@ -297,6 +298,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen
 bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
 ```
 
+#### 训练时推理
+
+训练时推理需要在启动文件中添加`run_eval`并设置为True。与此同时需要设置:`eval_dataset_path`、`save_best_ckpt`、`eval_start_epoch`、`eval_interval`。
+
 ### 结果
 
 - 使用CIFAR-10数据集训练ResNet18

diff --git a/model_zoo/official/cv/resnet/src/eval_callback.py b/model_zoo/official/cv/resnet/src/eval_callback.py
new file mode 100644
index 0000000000..205fce0eaf
--- /dev/null
+++ b/model_zoo/official/cv/resnet/src/eval_callback.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation callback when training"""
+
+import os
+import stat
+from mindspore import save_checkpoint
+from mindspore import log as logger
+from mindspore.train.callback import Callback
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback when training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configuration dict.
+        interval (int): run evaluation interval, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
+        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+
+    Returns:
+        None
+
+    Examples:
+        >>> EvalCallBack(eval_function, eval_param_dict)
+    """
+
+    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
+                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
+        super(EvalCallBack, self).__init__()
+        self.eval_param_dict = eval_param_dict
+        self.eval_function = eval_function
+        self.eval_start_epoch = eval_start_epoch
+        if interval < 1:
+            raise ValueError("interval should >= 1.")
+        self.interval = interval
+        self.save_best_ckpt = save_best_ckpt
+        self.best_res = 0
+        self.best_epoch = 0
+        if not os.path.isdir(ckpt_directory):
+            os.makedirs(ckpt_directory)
+        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
+        self.metrics_name = metrics_name
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def epoch_end(self, run_context):
+        """Callback when epoch end."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
+            res = self.eval_function(self.eval_param_dict)
+            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
+            if res >= self.best_res:
+                self.best_res = res
+                self.best_epoch = cur_epoch
+                print("update best result: {}".format(res), flush=True)
+                if self.save_best_ckpt:
+                    if os.path.exists(self.bast_ckpt_path):
+                        self.remove_ckpoint_file(self.bast_ckpt_path)
+                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
+                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
+
+    def end(self, run_context):
+        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
+                                                                                     self.best_res,
+                                                                                     self.best_epoch), flush=True)
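For context, a minimal usage sketch of this callback (a sketch only: `model`, `eval_dataset` and `train_dataset` are assumed to exist; `apply_eval` here mirrors the helper added in `train.py` below):

```python
# Sketch: wire EvalCallBack into Model.train.
from src.eval_callback import EvalCallBack

def apply_eval(eval_param):
    # eval_param is the eval_param_dict handed to EvalCallBack.
    return eval_param["model"].eval(eval_param["dataset"])[eval_param["metrics_name"]]

eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=1, eval_start_epoch=40,
                       save_best_ckpt=True, ckpt_directory="./ckpt",
                       besk_ckpt_name="best_acc.ckpt", metrics_name="acc")
model.train(90, train_dataset, callbacks=[eval_cb])
```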
diff --git a/model_zoo/official/cv/resnet/src/metric.py b/model_zoo/official/cv/resnet/src/metric.py
new file mode 100644
index 0000000000..497dd6640c
--- /dev/null
+++ b/model_zoo/official/cv/resnet/src/metric.py
@@ -0,0 +1,132 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""evaluation metric."""
+
+from mindspore.communication.management import GlobalComm
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+
+
+class ClassifyCorrectCell(nn.Cell):
+    r"""
+    Cell that returns the correct count of the predictions in a classification network.
+    This Cell accepts a network as its argument.
+    It returns the correct count of the predictions to calculate the metrics.
+
+    Args:
+        network (Cell): The network Cell.
+
+    Inputs:
+        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+
+    Outputs:
+        Tuple, containing a scalar correct count of the predictions.
+
+    Examples:
+        >>> # For a defined network Net without loss function
+        >>> net = Net()
+        >>> eval_net = ClassifyCorrectCell(net)
+    """
+
+    def __init__(self, network):
+        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
+        self._network = network
+        self.argmax = P.Argmax()
+        self.equal = P.Equal()
+        self.cast = P.Cast()
+        self.reduce_sum = P.ReduceSum()
+        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
+
+    def construct(self, data, label):
+        outputs = self._network(data)
+        y_pred = self.argmax(outputs)
+        y_pred = self.cast(y_pred, mstype.int32)
+        y_correct = self.equal(y_pred, label)
+        y_correct = self.cast(y_correct, mstype.float32)
+        y_correct = self.reduce_sum(y_correct)
+        total_correct = self.allreduce(y_correct)
+        return (total_correct,)
+
+
+class DistAccuracy(nn.Metric):
+    r"""
+    Calculates the accuracy for classification data in distributed mode.
+    The accuracy class creates two local variables, correct number and total number, that are used to compute the
+    frequency with which predictions match labels. This frequency is ultimately returned as the accuracy: an
+    idempotent operation that simply divides the correct number by the total number.
+
+    .. math::
+
+        \text{accuracy} = \frac{\text{true_positive} + \text{true_negative}}
+        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
+
+    Args:
+        batch_size (int): Evaluation batch size on a single device.
+        device_num (int): Number of devices participating in the evaluation.
+
+    Examples:
+        >>> y_correct = Tensor(np.array([20]))
+        >>> metric = DistAccuracy(batch_size=3, device_num=8)
+        >>> metric.clear()
+        >>> metric.update(y_correct)
+        >>> accuracy = metric.eval()
+    """
+
+    def __init__(self, batch_size, device_num):
+        super(DistAccuracy, self).__init__()
+        self.clear()
+        self.batch_size = batch_size
+        self.device_num = device_num
+
+    def clear(self):
+        """Clears the internal evaluation result."""
+        self._correct_num = 0
+        self._total_num = 0
+
+    def update(self, *inputs):
+        """
+        Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
+
+        Args:
+            inputs: Input `y_correct`. `y_correct` is a scalar Tensor.
+                It is the correct prediction count gathered from all devices; a float scalar.
+
+        Raises:
+            ValueError: If the number of the inputs is not 1.
+        """
+
+        if len(inputs) != 1:
+            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
+        y_correct = self._convert_data(inputs[0])
+        self._correct_num += y_correct
+        self._total_num += self.batch_size * self.device_num
+
+    def eval(self):
+        """
+        Computes the accuracy.
+
+        Returns:
+            Float, the computed result.
+
+        Raises:
+            RuntimeError: If the sample size is 0.
+ """ + + if self._total_num == 0: + raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.') + return self._correct_num / self._total_num diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index e48a339a02..f875149dba 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -31,9 +31,12 @@ from mindspore.common import set_seed from mindspore.parallel import set_algo_parameters import mindspore.nn as nn import mindspore.common.initializer as weight_init +import mindspore.log as logger from src.lr_generator import get_lr, warmup_cosine_annealing_lr from src.CrossEntropySmooth import CrossEntropySmooth from src.config import cfg +from src.eval_callback import EvalCallBack +from src.metric import DistAccuracy, ClassifyCorrectCell parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--net', type=str, default=None, help='Resnet Model, resnet18, resnet50 or resnet101') @@ -48,6 +51,15 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train') parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, help="Filter head weight parameters, default is False.") +parser.add_argument("--run_eval", type=ast.literal_eval, default=False, + help="Run evaluation when training, default is False.") +parser.add_argument('--eval_dataset_path', type=str, default=None, help='Evaluation dataset path when run_eval is True') +parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, + help="Save best checkpoint when run_eval is True, default is True.") +parser.add_argument("--eval_start_epoch", type=int, default=40, + help="Evaluation start epoch when run_eval is True, default is 40.") +parser.add_argument("--eval_interval", type=int, default=1, + help="Evaluation interval when run_eval is True, default is 1.") args_opt = parser.parse_args() set_seed(1) @@ -89,6 +101,12 @@ def filter_checkpoint_parameter_by_list(origin_dict, param_filter): del origin_dict[key] break +def apply_eval(eval_param): + eval_model = eval_param["model"] + eval_ds = eval_param["dataset"] + metrics_name = eval_param["metrics_name"] + res = eval_model.eval(eval_ds) + return res[metrics_name] if __name__ == '__main__': target = args_opt.device_target @@ -185,12 +203,16 @@ if __name__ == '__main__': else: loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False) + dist_eval_network = ClassifyCorrectCell(net) if args_opt.run_distribute else None + metrics = {"acc"} + if args_opt.run_distribute: + metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)} + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, + amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network) if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \ args_opt.parameter_server or target == "CPU": ## fp32 training - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network) if cfg.optimizer == "Thor" and 
args_opt.dataset == "imagenet2012": from src.lr_generator import get_thor_damping damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size) @@ -201,6 +223,8 @@ if __name__ == '__main__': loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False, frequency=config.frequency) + args_opt.run_eval = False + logger.warning("Thor optimizer not support evaluation while training.") # define callbacks time_cb = TimeMonitor(data_size=step_size) @@ -211,7 +235,17 @@ if __name__ == '__main__': keep_checkpoint_max=config.keep_checkpoint_max) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) cb += [ckpt_cb] - + if args_opt.run_eval: + if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)): + raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path)) + eval_dataset = create_dataset(dataset_path=args_opt.eval_dataset_path, do_train=False, + batch_size=config.batch_size, target=target) + eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"} + eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, + eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, + ckpt_directory=ckpt_save_dir, besk_ckpt_name="best_acc.ckpt", + metrics_name="acc") + cb += [eval_cb] # train model if args_opt.net == "se-resnet50": config.epoch_size = config.train_epoch_size diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md index 3e30df696e..0c4ba269d2 100644 --- a/model_zoo/official/cv/ssd/README.md +++ b/model_zoo/official/cv/ssd/README.md @@ -123,8 +123,8 @@ Dataset used: [COCO2017]() ### Prepare the model -1. Chose the model by changing the `using_model` in `src/confgi.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`. -2. Change the dataset config in the corresponding config. `src/config_ssd300.py` or `src/config_ssd_mobilenet_v1_fpn.py`. +1. Chose the model by changing the `using_model` in `src/confgi.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`, `ssd_mobilenet_v1_fpn`, `ssd_resnet50_fpn`. +2. Change the dataset config in the corresponding config. `src/config_ssd300.py`, `src/config_ssd_mobilenet_v1_fpn.py`, `src/config_ssd_resnet50_fpn.py`, `src/config_ssd_vgg16.py`. 3. If you are running with `ssd_mobilenet_v1_fpn`, you need a pretrained model for `mobilenet_v1`. Set the checkpoint path to `feature_extractor_base_param` in `src/config_ssd_mobilenet_v1_fpn.py`. For more detail about training mobilnet_v1, please refer to the mobilenetv1 model. ### Run the scripts @@ -201,6 +201,7 @@ Then you can run everything just like on ascend. ├─ src ├─ __init__.py # init file ├─ box_utils.py # bbox utils + ├─ eval_callback.py # evaluation callback when training ├─ eval_utils.py # metrics utils ├─ config.py # total config ├─ dataset.py # create dataset and process dataset @@ -229,6 +230,10 @@ Then you can run everything just like on ascend. "loss_scale": 1024 # Loss scale "filter_weight": False # Load parameters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True. "freeze_layer": "none" # Freeze the backbone parameters or not, support none and backbone. 
+ "run_eval": False # Run evaluation when training + "save_best_ckpt": True # Save best checkpoint when run_eval is True + "eval_start_epoch": 40 # Evaluation start epoch when run_eval is True + "eval_interval": 1 # valuation interval when run_eval is True "class_num": 81 # Dataset class number "image_shape": [300, 300] # Image height and width used as input to the model @@ -311,6 +316,10 @@ epoch time: 150753.701, per step time: 329.157 ... ``` +#### Evaluation while training + +You can add `run_eval` to start shell and set it True, if you want evaluation while training. And you can set argument option: `save_best_ckpt`, `eval_start_epoch`, `eval_interval` when `run_eval` is True. + #### Transfer Training You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps. diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index ca4bea9cee..aee53937ab 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -17,14 +17,12 @@ import os import argparse -import time -import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 from src.dataset import create_ssd_dataset, create_mindrecord from src.config import config -from src.eval_utils import metrics +from src.eval_utils import apply_eval from src.box_utils import default_boxes def ssd_eval(dataset_path, ckpt_path, anno_json): @@ -50,31 +48,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): load_param_into_net(net, param_dict) net.set_train(False) - i = batch_size total = ds.get_dataset_size() * batch_size - start = time.time() - pred_data = [] print("\n========================================\n") print("total images num: ", total) print("Processing, please wait a moment.") - for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1): - img_id = data['img_id'] - img_np = data['image'] - image_shape = data['image_shape'] - - output = net(Tensor(img_np)) - for batch_idx in range(img_np.shape[0]): - pred_data.append({"boxes": output[0].asnumpy()[batch_idx], - "box_scores": output[1].asnumpy()[batch_idx], - "img_id": int(np.squeeze(img_id[batch_idx])), - "image_shape": image_shape[batch_idx]}) - percent = round(i / total * 100., 2) - - print(f' {str(percent)} [{i}/{total}]', end='\r') - i += batch_size - cost_time = int((time.time() - start) * 1000) - print(f' 100% [{total}/{total}] cost {cost_time} ms') - mAP = metrics(pred_data, anno_json) + eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json} + mAP = apply_eval(eval_param_dict) print("\n========================================\n") print(f"mAP: {mAP}") diff --git a/model_zoo/official/cv/ssd/src/eval_callback.py b/model_zoo/official/cv/ssd/src/eval_callback.py new file mode 100644 index 0000000000..205fce0eaf --- /dev/null +++ b/model_zoo/official/cv/ssd/src/eval_callback.py @@ -0,0 +1,90 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+
 #### Transfer Training
 
 You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.

diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py
index ca4bea9cee..aee53937ab 100644
--- a/model_zoo/official/cv/ssd/eval.py
+++ b/model_zoo/official/cv/ssd/eval.py
@@ -17,14 +17,12 @@
 
 import os
 import argparse
-import time
-import numpy as np
 from mindspore import context, Tensor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
 from src.dataset import create_ssd_dataset, create_mindrecord
 from src.config import config
-from src.eval_utils import metrics
+from src.eval_utils import apply_eval
 from src.box_utils import default_boxes
 
 def ssd_eval(dataset_path, ckpt_path, anno_json):
@@ -50,31 +48,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
     load_param_into_net(net, param_dict)
     net.set_train(False)
 
-    i = batch_size
     total = ds.get_dataset_size() * batch_size
-    start = time.time()
-    pred_data = []
     print("\n========================================\n")
     print("total images num: ", total)
     print("Processing, please wait a moment.")
-    for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
-        img_id = data['img_id']
-        img_np = data['image']
-        image_shape = data['image_shape']
-
-        output = net(Tensor(img_np))
-        for batch_idx in range(img_np.shape[0]):
-            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
-                              "box_scores": output[1].asnumpy()[batch_idx],
-                              "img_id": int(np.squeeze(img_id[batch_idx])),
-                              "image_shape": image_shape[batch_idx]})
-        percent = round(i / total * 100., 2)
-
-        print(f' {str(percent)} [{i}/{total}]', end='\r')
-        i += batch_size
-    cost_time = int((time.time() - start) * 1000)
-    print(f' 100% [{total}/{total}] cost {cost_time} ms')
-    mAP = metrics(pred_data, anno_json)
+    eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json}
+    mAP = apply_eval(eval_param_dict)
     print("\n========================================\n")
     print(f"mAP: {mAP}")

diff --git a/model_zoo/official/cv/ssd/src/eval_callback.py b/model_zoo/official/cv/ssd/src/eval_callback.py
new file mode 100644
index 0000000000..205fce0eaf
--- /dev/null
+++ b/model_zoo/official/cv/ssd/src/eval_callback.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation callback when training"""
+
+import os
+import stat
+from mindspore import save_checkpoint
+from mindspore import log as logger
+from mindspore.train.callback import Callback
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback when training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configuration dict.
+        interval (int): run evaluation interval, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
+        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+
+    Returns:
+        None
+
+    Examples:
+        >>> EvalCallBack(eval_function, eval_param_dict)
+    """
+
+    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
+                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
+        super(EvalCallBack, self).__init__()
+        self.eval_param_dict = eval_param_dict
+        self.eval_function = eval_function
+        self.eval_start_epoch = eval_start_epoch
+        if interval < 1:
+            raise ValueError("interval should >= 1.")
+        self.interval = interval
+        self.save_best_ckpt = save_best_ckpt
+        self.best_res = 0
+        self.best_epoch = 0
+        if not os.path.isdir(ckpt_directory):
+            os.makedirs(ckpt_directory)
+        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
+        self.metrics_name = metrics_name
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def epoch_end(self, run_context):
+        """Callback when epoch end."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
+            res = self.eval_function(self.eval_param_dict)
+            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
+            if res >= self.best_res:
+                self.best_res = res
+                self.best_epoch = cur_epoch
+                print("update best result: {}".format(res), flush=True)
+                if self.save_best_ckpt:
+                    if os.path.exists(self.bast_ckpt_path):
+                        self.remove_ckpoint_file(self.bast_ckpt_path)
+                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
+                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
+
+    def end(self, run_context):
+        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
+                                                                                     self.best_res,
+                                                                                     self.best_epoch), flush=True)
diff --git a/model_zoo/official/cv/ssd/src/eval_utils.py b/model_zoo/official/cv/ssd/src/eval_utils.py
index e8e01b32c0..6fcdb8ba92 100644
--- a/model_zoo/official/cv/ssd/src/eval_utils.py
+++ b/model_zoo/official/cv/ssd/src/eval_utils.py
@@ -16,8 +16,28 @@
 
 import json
 import numpy as np
+from mindspore import Tensor
 from .config import config
 
+def apply_eval(eval_param_dict):
+    net = eval_param_dict["net"]
+    net.set_train(False)
+    ds = eval_param_dict["dataset"]
+    anno_json = eval_param_dict["anno_json"]
+    pred_data = []
+    for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
+        img_id = data['img_id']
+        img_np = data['image']
+        image_shape = data['image_shape']
+
+        output = net(Tensor(img_np))
+        for batch_idx in range(img_np.shape[0]):
+            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
+                              "box_scores": output[1].asnumpy()[batch_idx],
+                              "img_id": int(np.squeeze(img_id[batch_idx])),
+                              "image_shape": image_shape[batch_idx]})
+    mAP = metrics(pred_data, anno_json)
+    return mAP
 
 def apply_nms(all_boxes, all_scores, thres, max_boxes):
     """Apply NMS to bboxes."""

diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py
index cf182e3792..efee81a2d1 100644
--- a/model_zoo/official/cv/ssd/train.py
+++ b/model_zoo/official/cv/ssd/train.py
@@ -15,6 +15,7 @@
 
 """Train SSD and get checkpoint files."""
 
+import os
 import argparse
 import ast
 import mindspore.nn as nn
@@ -25,11 +26,15 @@ from mindspore.train import Model
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed, dtype
-from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
+from src.ssd import SSD300, SsdInferWithDecoder, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,\
+    ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
 from src.config import config
 from src.dataset import create_ssd_dataset, create_mindrecord
 from src.lr_schedule import get_lr
 from src.init_params import init_net_param, filter_checkpoint_parameter_by_list
+from src.eval_callback import EvalCallBack
+from src.eval_utils import apply_eval
+from src.box_utils import default_boxes
 
 set_seed(1)
 
@@ -57,6 +62,14 @@ def get_args():
     parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
                         help="freeze the weights of network, support freeze the backbone's weights, "
                              "default is not freezing.")
+    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
+                        help="Run evaluation when training, default is False.")
+    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
+                        help="Save best checkpoint when run_eval is True, default is True.")
+    parser.add_argument("--eval_start_epoch", type=int, default=40,
+                        help="Evaluation start epoch when run_eval is True, default is 40.")
+    parser.add_argument("--eval_interval", type=int, default=1,
+                        help="Evaluation interval when run_eval is True, default is 1.")
     args_opt = parser.parse_args()
     return args_opt
 
@@ -170,8 +183,25 @@ def main():
                       config.momentum, config.weight_decay, loss_scale)
         net = TrainingWrapper(net, opt, loss_scale)
 
-    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
+    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
+
+    if args_opt.run_eval:
+        eval_net = SsdInferWithDecoder(ssd, Tensor(default_boxes), config)
+        eval_net.set_train(False)
+        mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)
+        eval_dataset = create_ssd_dataset(mindrecord_file, batch_size=args_opt.batch_size, repeat_num=1,
+                                          is_training=False, use_multiprocessing=False)
+        if args_opt.dataset == "coco":
+            anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
+        elif args_opt.dataset == "voc":
+            anno_json = os.path.join(config.voc_root, config.voc_json)
+        else:
+            raise ValueError('SSD evaluation only supports the coco and voc datasets!')
+        eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json}
+        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
+                               eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=args_opt.save_best_ckpt,
+                               ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt",
+                               metrics_name="mAP")
+        callback.append(eval_cb)
 
     model = Model(net)
     dataset_sink_mode = False
     if args_opt.mode == "sink" and args_opt.run_platform != "CPU":

diff --git a/model_zoo/official/cv/unet/README.md b/model_zoo/official/cv/unet/README.md
index 37de66a5d8..9d413e7f4a 100644
--- a/model_zoo/official/cv/unet/README.md
+++ b/model_zoo/official/cv/unet/README.md
@@ -128,6 +128,7 @@ Then you can run everything just like on ascend.
 │   ├──config.py // parameter configuration
 │   ├──data_loader.py // creating dataset
 │   ├──loss.py // loss
+│   ├──eval_callback.py // evaluation callback while training
 │   ├──utils.py // General components (callback function)
 │   ├──unet_medical // Unet medical architecture
 ├──__init__.py // init file
@@ -168,6 +169,11 @@ Parameters for both training and evaluation can be set in config.py
 'resume_ckpt': './', # pretrain model path
 'transfer_training': False # whether do transfer training
 'filter_weight': ["final.weight"] # weight name to filter while doing transfer training
+'run_eval': False # Run evaluation when training
+'save_best_ckpt': True # Save best checkpoint when run_eval is True
+'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
+'eval_interval': 1 # Evaluation interval when run_eval is True
+
 ```
 
 - config for Unet++, cell nuclei dataset
@@ -193,6 +199,10 @@ Parameters for both training and evaluation can be set in config.py
 'resume_ckpt': './', # pretrain model path
 'transfer_training': False # whether do transfer training
 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight name to filter while doing transfer training
+'run_eval': False # Run evaluation when training
+'save_best_ckpt': True # Save best checkpoint when run_eval is True
+'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
+'eval_interval': 1 # Evaluation interval when run_eval is True
 ```
 
 ## [Training Process](#contents)
@@ -245,6 +255,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
 step: 300, loss is 0.18949677, fps is 57.63118508760329
 ```
 
+#### Evaluation while training
+
+To evaluate while training, add `run_eval=True` to the start shell. When `run_eval` is True, you can also set the options `save_best_ckpt`, `eval_start_epoch`, `eval_interval` and `eval_metrics`.
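+
+For example, a typical invocation might look like this (the data path is a placeholder; adjust it to your environment):
+
+```bash
+python train.py --data_url=/path/to/dataset --run_eval=True --save_best_ckpt=True \
+    --eval_start_epoch=0 --eval_interval=1 --eval_metrics=dice_coeff
+```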
+
 ## [Evaluation Process](#contents)
 
 ### Evaluation

diff --git a/model_zoo/official/cv/unet/README_CN.md b/model_zoo/official/cv/unet/README_CN.md
index 2f928ea228..e40d3d3f86 100644
--- a/model_zoo/official/cv/unet/README_CN.md
+++ b/model_zoo/official/cv/unet/README_CN.md
@@ -132,6 +132,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
 │   ├──config.py // 参数配置
 │   ├──data_loader.py // 数据处理
 │   ├──loss.py // 损失函数
+│   ├──eval_callback.py // 训练时推理回调函数
 │   ├──utils.py // 通用组件(回调函数)
 │   ├──unet_medical // 医学图像处理Unet结构
 ├──__init__.py
@@ -247,6 +248,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
 step: 300, loss is 0.18949677, fps is 57.63118508760329
 ```
 
+#### 训练时推理
+
+训练时推理需要在启动文件中添加`run_eval`并设置为True。与此同时需要设置:`save_best_ckpt`、`eval_start_epoch`、`eval_interval`、`eval_metrics`。
+
 ## 评估过程
 
 ### 评估

diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py
index a7fad58e28..d043591ff0 100644
--- a/model_zoo/official/cv/unet/eval.py
+++ b/model_zoo/official/cv/unet/eval.py
@@ -16,10 +16,6 @@
 
 import os
 import argparse
 import logging
-import cv2
-import numpy as np
-import mindspore.nn as nn
-import mindspore.ops.operations as F
 
 from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -27,76 +23,11 @@ from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
-from src.utils import UnetEval
+from src.utils import UnetEval, TempLoss, dice_coeff
 
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
 
-
-class TempLoss(nn.Cell):
-    """A temp loss cell."""
-    def __init__(self):
-        super(TempLoss, self).__init__()
-        self.identity = F.identity()
-    def construct(self, logits, label):
-        return self.identity(logits)
-
-
-class dice_coeff(nn.Metric):
-    def __init__(self):
-        super(dice_coeff, self).__init__()
-        self.clear()
-    def clear(self):
-        self._dice_coeff_sum = 0
-        self._iou_sum = 0
-        self._samples_num = 0
-
-    def update(self, *inputs):
-        if len(inputs) != 2:
-            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
-        y = self._convert_data(inputs[1])
-        self._samples_num += y.shape[0]
-        y = y.transpose(0, 2, 3, 1)
-        b, h, w, c = y.shape
-        if b != 1:
-            raise ValueError('Batch size should be 1 when in evaluation.')
-        y = y.reshape((h, w, c))
-        if cfg_unet["eval_activate"].lower() == "softmax":
-            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
-            if cfg_unet["eval_resize"]:
-                y_pred = []
-                for i in range(cfg_unet["num_classes"]):
-                    y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
-                y_pred = np.stack(y_pred, axis=-1)
-            else:
-                y_pred = y_softmax
-        elif cfg_unet["eval_activate"].lower() == "argmax":
-            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
-            y_pred = []
-            for i in range(cfg_unet["num_classes"]):
-                if cfg_unet["eval_resize"]:
-                    y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
-                else:
-                    y_pred.append(np.float32(y_argmax == i))
-            y_pred = np.stack(y_pred, axis=-1)
-        else:
-            raise ValueError('config eval_activate should be softmax or argmax.')
-        y_pred = y_pred.astype(np.float32)
-        inter = np.dot(y_pred.flatten(), y.flatten())
-        union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
-
-        single_dice_coeff = 2*float(inter)/float(union+1e-6)
-        single_iou = single_dice_coeff / (2 - single_dice_coeff)
-        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
-        self._dice_coeff_sum += single_dice_coeff
-        self._iou_sum += single_iou
-
-    def eval(self):
-        if self._samples_num == 0:
-            raise RuntimeError('Total samples num must not be 0.')
-        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
-
-
 def test_net(data_dir,
              ckpt_path,
              cross_valid_ind=1,
@@ -119,7 +50,7 @@ def test_net(data_dir,
     else:
         _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
                                           do_crop=cfg['crop'], img_size=cfg['img_size'])
-    model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
+    model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)})
 
     print("============== Starting Evaluating ============")
     eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]

diff --git a/model_zoo/official/cv/unet/src/eval_callback.py b/model_zoo/official/cv/unet/src/eval_callback.py
new file mode 100644
index 0000000000..205fce0eaf
--- /dev/null
+++ b/model_zoo/official/cv/unet/src/eval_callback.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation callback when training"""
+
+import os
+import stat
+from mindspore import save_checkpoint
+from mindspore import log as logger
+from mindspore.train.callback import Callback
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback when training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configuration dict.
+        interval (int): run evaluation interval, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
+        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+ + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, + ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): + super(EvalCallBack, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_res = 0 + self.best_epoch = 0 + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = metrics_name + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + res = self.eval_function(self.eval_param_dict) + print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) + if res >= self.best_res: + self.best_res = res + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, + self.best_res, + self.best_epoch), flush=True) diff --git a/model_zoo/official/cv/unet/src/loss.py b/model_zoo/official/cv/unet/src/loss.py index e1ba7cb9cf..81c2528863 100644 --- a/model_zoo/official/cv/unet/src/loss.py +++ b/model_zoo/official/cv/unet/src/loss.py @@ -41,7 +41,7 @@ class MultiCrossEntropyWithLogits(nn.Cell): def __init__(self): super(MultiCrossEntropyWithLogits, self).__init__() self.loss = CrossEntropyWithLogits() - self.squeeze = F.Squeeze() + self.squeeze = F.Squeeze(axis=0) def construct(self, logits, label): total_loss = 0 diff --git a/model_zoo/official/cv/unet/src/utils.py b/model_zoo/official/cv/unet/src/utils.py index 5285ee5c83..981ed2a181 100644 --- a/model_zoo/official/cv/unet/src/utils.py +++ b/model_zoo/official/cv/unet/src/utils.py @@ -14,6 +14,7 @@ # ============================================================================ import time +import cv2 import numpy as np from PIL import Image from mindspore import nn @@ -25,20 +26,100 @@ class UnetEval(nn.Cell): """ Add Unet evaluation activation. 
""" - def __init__(self, net): + def __init__(self, net, need_slice=False): super(UnetEval, self).__init__() self.net = net + self.need_slice = need_slice self.transpose = ops.Transpose() self.softmax = ops.Softmax(axis=-1) self.argmax = ops.Argmax(axis=-1) + self.squeeze = ops.Squeeze(axis=0) def construct(self, x): out = self.net(x) + if self.need_slice: + out = self.squeeze(out[-1:]) out = self.transpose(out, (0, 2, 3, 1)) softmax_out = self.softmax(out) argmax_out = self.argmax(out) return (softmax_out, argmax_out) +class TempLoss(nn.Cell): + """A temp loss cell.""" + def __init__(self): + super(TempLoss, self).__init__() + self.identity = ops.identity() + def construct(self, logits, label): + return self.identity(logits) + +def apply_eval(eval_param_dict): + """run Evaluation""" + model = eval_param_dict["model"] + dataset = eval_param_dict["dataset"] + metrics_name = eval_param_dict["metrics_name"] + index = 0 if metrics_name == "dice_coeff" else 1 + eval_score = model.eval(dataset, dataset_sink_mode=False)[metrics_name][index] + return eval_score + +class dice_coeff(nn.Metric): + """Unet Metric, return dice coefficient and IOU.""" + def __init__(self, cfg_unet, print_res=True): + super(dice_coeff, self).__init__() + self.clear() + self.cfg_unet = cfg_unet + self.print_res = print_res + + def clear(self): + self._dice_coeff_sum = 0 + self._iou_sum = 0 + self._samples_num = 0 + + def update(self, *inputs): + if len(inputs) != 2: + raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) + y = self._convert_data(inputs[1]) + self._samples_num += y.shape[0] + y = y.transpose(0, 2, 3, 1) + b, h, w, c = y.shape + if b != 1: + raise ValueError('Batch size should be 1 when in evaluation.') + y = y.reshape((h, w, c)) + if self.cfg_unet["eval_activate"].lower() == "softmax": + y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) + if self.cfg_unet["eval_resize"]: + y_pred = [] + for i in range(self.cfg_unet["num_classes"]): + y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) + y_pred = np.stack(y_pred, axis=-1) + else: + y_pred = y_softmax + elif self.cfg_unet["eval_activate"].lower() == "argmax": + y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) + y_pred = [] + for i in range(self.cfg_unet["num_classes"]): + if self.cfg_unet["eval_resize"]: + y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) + else: + y_pred.append(np.float32(y_argmax == i)) + y_pred = np.stack(y_pred, axis=-1) + else: + raise ValueError('config eval_activate should be softmax or argmax.') + y_pred = y_pred.astype(np.float32) + inter = np.dot(y_pred.flatten(), y.flatten()) + union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) + + single_dice_coeff = 2 * float(inter) / float(union+1e-6) + single_iou = single_dice_coeff / (2 - single_dice_coeff) + if self.print_res: + print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou)) + self._dice_coeff_sum += single_dice_coeff + self._iou_sum += single_iou + + def eval(self): + if self._samples_num == 0: + raise RuntimeError('Total samples num must not be 0.') + return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) + class StepLossTimeMonitor(Callback): def __init__(self, batch_size, per_print_times=1): diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py index c0d8362bb1..252dd2cb3f 100644 --- 
diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py
index c0d8362bb1..252dd2cb3f 100755
--- a/model_zoo/official/cv/unet/train.py
+++ b/model_zoo/official/cv/unet/train.py
@@ -30,23 +30,25 @@ from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
-from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list
+from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
 from src.config import cfg_unet
+from src.eval_callback import EvalCallBack
 
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
 
 mindspore.set_seed(1)
 
-def train_net(data_dir,
+def train_net(args_opt,
               cross_valid_ind=1,
               epochs=400,
               batch_size=16,
               lr=0.0001,
-              run_distribute=False,
               cfg=None):
     rank = 0
     group_size = 1
+    data_dir = args_opt.data_url
+    run_distribute = args_opt.run_distribute
     if run_distribute:
         init()
         group_size = get_group_size()
@@ -55,12 +57,13 @@ def train_net(data_dir,
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=group_size,
                                           gradients_mean=False)
-
+    need_slice = False
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
         net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
                          use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
+        need_slice = cfg['use_ds']
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -83,12 +86,15 @@ def train_net(data_dir,
         train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
                                                    is_train=True, augment=True, split=0.8, rank=rank,
                                                    group_size=group_size)
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
+                                                   eval_resize=cfg["eval_resize"], split=0.8,
+                                                   python_multiprocessing=False)
     else:
         repeat = epochs
         dataset_sink_mode = False
         per_print_times = 1
-        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
-                                          cfg["crop"], cfg['img_size'])
+        train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
+                                                      run_distribute, cfg["crop"], cfg['img_size'])
     train_data_size = train_dataset.get_dataset_size()
     print("dataset length is:", train_data_size)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
@@ -106,6 +112,15 @@ def train_net(data_dir,
 
     print("============== Starting Training ==============")
     callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
+    if args_opt.run_eval:
+        eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
+                           metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
+        eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
+        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
+                               eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=args_opt.save_best_ckpt,
+                               ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
+                               metrics_name=args_opt.eval_metrics)
+        callbacks.append(eval_cb)
     model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
     print("============== End Training ==============")
 
@@ -117,6 +132,17 @@ def get_args():
                         help='data directory')
     parser.add_argument('-t', '--run_distribute', type=ast.literal_eval, default=False,
                         help='Run distribute, default: false.')
+    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
+                        help="Run evaluation when training, default is False.")
+    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
+                        help="Save best checkpoint when run_eval is True, default is True.")
+    parser.add_argument("--eval_start_epoch", type=int, default=0,
+                        help="Evaluation start epoch when run_eval is True, default is 0.")
+    parser.add_argument("--eval_interval", type=int, default=1,
+                        help="Evaluation interval when run_eval is True, default is 1.")
+    parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
+                        help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
+                             "default is dice_coeff.")
 
     return parser.parse_args()
 
 
@@ -127,10 +153,9 @@ if __name__ == '__main__':
     print("Training setting:", args)
 
     epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
-    train_net(data_dir=args.data_url,
+    train_net(args_opt=args,
               cross_valid_ind=cfg_unet['cross_valid_ind'],
               epochs=epoch_size,
               batch_size=cfg_unet['batchsize'],
               lr=cfg_unet['lr'],
-              run_distribute=args.run_distribute,
               cfg=cfg_unet)