From c929dab3ddb02e0e8c08079c2172ecc49f83e1af Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 9 Jun 2020 11:38:07 +0800 Subject: [PATCH] 1:modify shell for deeplabv3 2:fix normalize bug 3:add ci test3:add ci test3:add ci test --- model_zoo/deeplabv3/README.md | 16 ++-- .../deeplabv3/{evaluation.py => eval.py} | 6 +- .../deeplabv3/scripts/run_distribute_train.sh | 24 ++--- model_zoo/deeplabv3/scripts/run_eval.sh | 10 +- .../deeplabv3/scripts/run_standalone_train.sh | 18 ++-- model_zoo/deeplabv3/src/config.py | 7 +- model_zoo/deeplabv3/src/md_dataset.py | 13 +-- .../deeplabv3/src/utils/custom_transforms.py | 1 + model_zoo/deeplabv3/train.py | 18 ++-- .../deeplabv3/run_deeplabv3_ci.sh | 47 +++++++++ .../deeplabv3/train_one_epoch_with_loss.py | 96 +++++++++++++++++++ 11 files changed, 201 insertions(+), 55 deletions(-) rename model_zoo/deeplabv3/{evaluation.py => eval.py} (85%) create mode 100644 tests/st/model_zoo_tests/deeplabv3/run_deeplabv3_ci.sh create mode 100644 tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py diff --git a/model_zoo/deeplabv3/README.md b/model_zoo/deeplabv3/README.md index b178a83e6da..c8df3dab8dd 100644 --- a/model_zoo/deeplabv3/README.md +++ b/model_zoo/deeplabv3/README.md @@ -16,17 +16,17 @@ This is an example of training DeepLabv3 with PASCAL VOC 2012 dataset in MindSpo - Set options in config.py. - Run `run_standalone_train.sh` for non-distributed training. ``` bash - sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR + sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH ``` - Run `run_distribute_train.sh` for distributed training. ``` bash - sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH + sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH ``` ### Evaluation Set options in evaluation_config.py. Make sure the 'data_file' and 'finetune_ckpt' are set to your own path. - Run run_eval.sh for evaluation. ``` bash - sh scripts/run_eval.sh DEVICE_ID DATA_DIR + sh scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH ``` ## Options and Parameters @@ -49,6 +49,11 @@ config.py: decoder_output_stride The ratio of input to output spatial resolution when employing decoder to refine segmentation results, default is None. image_pyramid Input scales for multi-scale feature extraction, default is None. + epoch_size Epoch size, default is 6. + batch_size batch size of input dataset: N, default is 2. + enable_save_ckpt Enable save checkpoint, default is true. + save_checkpoint_steps Save checkpoint steps, default is 1000. + save_checkpoint_num Save checkpoint numbers, default is 1. ``` @@ -56,11 +61,6 @@ config.py: ``` Parameters for dataset and network: distribute Run distribute, default is false. - epoch_size Epoch size, default is 6. - batch_size batch size of input dataset: N, default is 2. data_url Train/Evaluation data url, required. checkpoint_url Checkpoint path, default is None. - enable_save_ckpt Enable save checkpoint, default is true. - save_checkpoint_steps Save checkpoint steps, default is 1000. - save_checkpoint_num Save checkpoint numbers, default is 1. ``` \ No newline at end of file diff --git a/model_zoo/deeplabv3/evaluation.py b/model_zoo/deeplabv3/eval.py similarity index 85% rename from model_zoo/deeplabv3/evaluation.py rename to model_zoo/deeplabv3/eval.py index e54b2d717bb..7e435719827 100644 --- a/model_zoo/deeplabv3/evaluation.py +++ b/model_zoo/deeplabv3/eval.py @@ -25,9 +25,7 @@ from src.config import config parser = argparse.ArgumentParser(description="Deeplabv3 evaluation") -parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.') parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url') parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') @@ -39,8 +37,8 @@ print(args_opt) if __name__ == "__main__": args_opt.crop_size = config.crop_size args_opt.base_size = config.crop_size - eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval") - net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], + eval_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="eval") + net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) diff --git a/model_zoo/deeplabv3/scripts/run_distribute_train.sh b/model_zoo/deeplabv3/scripts/run_distribute_train.sh index 514b0229af0..4dcd8d9768f 100644 --- a/model_zoo/deeplabv3/scripts/run_distribute_train.sh +++ b/model_zoo/deeplabv3/scripts/run_distribute_train.sh @@ -16,17 +16,21 @@ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH" -echo "for example: bash run_distribute_train.sh 8 40 /path/zh-wiki/ /path/hccl.json" +echo "bash run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH" +echo "for example: bash run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH [PRETRAINED_CKPT_PATH](option)" echo "It is better to use absolute path." echo "==============================================================================================================" -EPOCH_SIZE=$2 -DATA_DIR=$3 +DATA_DIR=$2 -export MINDSPORE_HCCL_CONFIG_PATH=$4 -export RANK_TABLE_FILE=$4 -export RANK_SIZE=$1 +export MINDSPORE_HCCL_CONFIG_PATH=$1 +export RANK_TABLE_FILE=$1 +export RANK_SIZE=8 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi cores=`cat /proc/cpuinfo|grep "processor" |wc -l` echo "the number of logical core" $cores avg_core_per_rank=`expr $cores \/ $RANK_SIZE` @@ -55,12 +59,8 @@ do env > env.log taskset -c $cmdopt python ../train.py \ --distribute="true" \ - --epoch_size=$EPOCH_SIZE \ --device_id=$DEVICE_ID \ - --enable_save_ckpt="true" \ - --checkpoint_url="" \ - --save_checkpoint_steps=10000 \ - --save_checkpoint_num=1 \ + --checkpoint_url=$PATH_CHECKPOINT \ --data_url=$DATA_DIR > log.txt 2>&1 & cd ../ done \ No newline at end of file diff --git a/model_zoo/deeplabv3/scripts/run_eval.sh b/model_zoo/deeplabv3/scripts/run_eval.sh index 2470138c335..735dce4cbe5 100644 --- a/model_zoo/deeplabv3/scripts/run_eval.sh +++ b/model_zoo/deeplabv3/scripts/run_eval.sh @@ -15,18 +15,20 @@ # ============================================================================ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "bash run_eval.sh DEVICE_ID DATA_DIR" -echo "for example: bash run_eval.sh /path/zh-wiki/ " +echo "bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH" +echo "for example: bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH" echo "==============================================================================================================" DEVICE_ID=$1 DATA_DIR=$2 +PATH_CHECKPOINT=$3 + mkdir -p ms_log CUR_DIR=`pwd` export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 -python evaluation.py \ +python eval.py \ --device_id=$DEVICE_ID \ - --checkpoint_url="" \ + --checkpoint_url=$PATH_CHECKPOINT \ --data_url=$DATA_DIR > log.txt 2>&1 & \ No newline at end of file diff --git a/model_zoo/deeplabv3/scripts/run_standalone_train.sh b/model_zoo/deeplabv3/scripts/run_standalone_train.sh index 1b84f9d5837..6f5e8dbe52a 100644 --- a/model_zoo/deeplabv3/scripts/run_standalone_train.sh +++ b/model_zoo/deeplabv3/scripts/run_standalone_train.sh @@ -15,13 +15,17 @@ # ============================================================================ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR" -echo "for example: bash run_standalone_train.sh 0 40 /path/zh-wiki/ " +echo "bash run_standalone_pretrain.sh DEVICE_ID DATA_PATH" +echo "for example: bash run_standalone_train.sh DEVICE_ID DATA_PATH [PRETRAINED_CKPT_PATH](option)" echo "==============================================================================================================" DEVICE_ID=$1 -EPOCH_SIZE=$2 -DATA_DIR=$3 +DATA_DIR=$2 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi mkdir -p ms_log CUR_DIR=`pwd` @@ -29,10 +33,6 @@ export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 python train.py \ --distribute="false" \ - --epoch_size=$EPOCH_SIZE \ --device_id=$DEVICE_ID \ - --enable_save_ckpt="true" \ - --checkpoint_url="" \ - --save_checkpoint_steps=10000 \ - --save_checkpoint_num=1 \ + --checkpoint_url=$PATH_CHECKPOINT \ --data_url=$DATA_DIR > log.txt 2>&1 & \ No newline at end of file diff --git a/model_zoo/deeplabv3/src/config.py b/model_zoo/deeplabv3/src/config.py index c3b73e10972..6b5519e46cc 100644 --- a/model_zoo/deeplabv3/src/config.py +++ b/model_zoo/deeplabv3/src/config.py @@ -29,5 +29,10 @@ config = ed({ "fine_tune_batch_norm": False, "ignore_label": 255, "decoder_output_stride": None, - "seg_num_classes": 21 + "seg_num_classes": 21, + "epoch_size": 6, + "batch_size": 2, + "enable_save_ckpt": True, + "save_checkpoint_steps": 10000, + "save_checkpoint_num": 1 }) diff --git a/model_zoo/deeplabv3/src/md_dataset.py b/model_zoo/deeplabv3/src/md_dataset.py index 37b57d10335..e136da23e13 100644 --- a/model_zoo/deeplabv3/src/md_dataset.py +++ b/model_zoo/deeplabv3/src/md_dataset.py @@ -16,6 +16,7 @@ from PIL import Image import mindspore.dataset as de import mindspore.dataset.transforms.vision.c_transforms as C +import numpy as np from .ei_dataset import HwVocRawDataset from .utils import custom_transforms as tr @@ -52,8 +53,8 @@ class DataTransform: rhf_tr = tr.RandomHorizontalFlip() image, label = rhf_tr(image, label) - nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) - image, label = nor_tr(image, label) + image = np.array(image).astype(np.float32) + label = np.array(label).astype(np.float32) return image, label @@ -71,13 +72,13 @@ class DataTransform: fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size) image, label = fsc_tr(image, label) - nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) - image, label = nor_tr(image, label) + image = np.array(image).astype(np.float32) + label = np.array(label).astype(np.float32) return image, label -def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"): +def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True): """ Create Dataset for DeepLabV3. @@ -106,7 +107,7 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"): # 1464 samples / batch_size 8 = 183 batches # epoch_num is num of steps # 3658 steps / 183 = 20 epochs - if usage == "train": + if usage == "train" and shuffle: dataset = dataset.shuffle(1464) dataset = dataset.batch(batch_size, drop_remainder=(usage == "train")) dataset = dataset.repeat(count=epoch_num) diff --git a/model_zoo/deeplabv3/src/utils/custom_transforms.py b/model_zoo/deeplabv3/src/utils/custom_transforms.py index 3473f7eef53..75c78e12409 100644 --- a/model_zoo/deeplabv3/src/utils/custom_transforms.py +++ b/model_zoo/deeplabv3/src/utils/custom_transforms.py @@ -33,6 +33,7 @@ class Normalize: def __call__(self, img, mask): img = np.array(img).astype(np.float32) mask = np.array(mask).astype(np.float32) + img = ((img - self.mean) / self.std).astype(np.float32) return img, mask diff --git a/model_zoo/deeplabv3/train.py b/model_zoo/deeplabv3/train.py index 2135b0abf55..d0966139771 100644 --- a/model_zoo/deeplabv3/train.py +++ b/model_zoo/deeplabv3/train.py @@ -27,14 +27,10 @@ from src.config import config parser = argparse.ArgumentParser(description="Deeplabv3 training") parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") -parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.') -parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') parser.add_argument('--data_url', required=True, default=None, help='Train data url') parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') -parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") -parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.") -parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") + args_opt = parser.parse_args() print(args_opt) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) @@ -70,16 +66,16 @@ if __name__ == "__main__": init() args_opt.base_size = config.crop_size args_opt.crop_size = config.crop_size - train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train") + train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train") dataset_size = train_dataset.get_dataset_size() time_cb = TimeMonitor(data_size=dataset_size) callback = [time_cb, LossCallBack()] - if args_opt.enable_save_ckpt == "true": - config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, - keep_checkpoint_max=args_opt.save_checkpoint_num) + if config.enable_save_ckpt: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + keep_checkpoint_max=config.save_checkpoint_num) ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) callback.append(ckpoint_cb) - net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], + net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) @@ -88,5 +84,5 @@ if __name__ == "__main__": loss = OhemLoss(config.seg_num_classes, config.ignore_label) opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) model = Model(net, loss, opt) - model.train(args_opt.epoch_size, train_dataset, callback) + model.train(config.epoch_size, train_dataset, callback) \ No newline at end of file diff --git a/tests/st/model_zoo_tests/deeplabv3/run_deeplabv3_ci.sh b/tests/st/model_zoo_tests/deeplabv3/run_deeplabv3_ci.sh new file mode 100644 index 00000000000..df24367417a --- /dev/null +++ b/tests/st/model_zoo_tests/deeplabv3/run_deeplabv3_ci.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "for example: bash run_deeplabv3_ci.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH" +echo "==============================================================================================================" +DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT=$3 +BASE_PATH=$(cd "$(dirname $0)"; pwd) +unset SLOG_PRINT_TO_STDOUT +CODE_DIR="./" +if [ -d ${BASE_PATH}/../../../../model_zoo/deeplabv3 ]; then + CODE_DIR=${BASE_PATH}/../../../../model_zoo/deeplabv3 +elif [ -d ${BASE_PATH}/../../model_zoo/deeplabv3 ]; then + CODE_DIR=${BASE_PATH}/../../model_zoo/deeplabv3 +else + echo "[ERROR] code dir is not found" +fi +echo $CODE_DIR +rm -rf ${BASE_PATH}/deeplabv3 +cp -r ${CODE_DIR} ${BASE_PATH}/deeplabv3 +cp -f ${BASE_PATH}/train_one_epoch_with_loss.py ${BASE_PATH}/deeplabv3/train_one_epoch_with_loss.py +cd ${BASE_PATH}/deeplabv3 +python train_one_epoch_with_loss.py --data_url=$DATA_DIR --checkpoint_url=$PATH_CHECKPOINT --device_id=$DEVICE_ID > train_deeplabv3_ci.log 2>&1 & +process_pid=`echo $!` +wait ${process_pid} +status=`echo $?` +if [ "${status}" != "0" ]; then + echo "[ERROR] test deeplabv3 failed. status: ${status}" + exit 1 +else + echo "[INFO] test deeplabv3 success." +fi \ No newline at end of file diff --git a/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py b/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py new file mode 100644 index 00000000000..73931a8046b --- /dev/null +++ b/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py @@ -0,0 +1,96 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train.""" +import argparse +import time +from mindspore import context +from mindspore.nn.optim.momentum import Momentum +from mindspore import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.callback import Callback +from src.md_dataset import create_dataset +from src.losses import OhemLoss +from src.deeplabv3 import deeplabv3_resnet50 +from src.config import config +parser = argparse.ArgumentParser(description="Deeplabv3 training") +parser.add_argument('--data_url', required=True, default=None, help='Train data url') +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') +args_opt = parser.parse_args() +print(args_opt) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + +class LossCallBack(Callback): + """ + Monitor the loss in training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, data_size, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self.data_size = data_size + self._per_print_times = per_print_times + self.time = 1000 + self.loss = 0 + def epoch_begin(self, run_context): + self.epoch_time = time.time() + def step_end(self, run_context): + cb_params = run_context.original_args() + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + self.time = epoch_mseconds / self.data_size + self.loss += cb_params.net_outputs + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) + +def model_fine_tune(flags, train_net, fix_weight_layer): + checkpoint_path = flags.checkpoint_url + if checkpoint_path is None: + return + param_dict = load_checkpoint(checkpoint_path) + load_param_into_net(train_net, param_dict) + for para in train_net.trainable_params(): + if fix_weight_layer in para.name: + para.requires_grad = False + +if __name__ == "__main__": + start_time = time.time() + epoch_size = 3 + args_opt.base_size = config.crop_size + args_opt.crop_size = config.crop_size + train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size, + usage="train", shuffle=False) + dataset_size = train_dataset.get_dataset_size() + callback = LossCallBack(dataset_size) + net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], + infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, + decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, + fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) + net.set_train() + model_fine_tune(args_opt, net, 'layer') + loss = OhemLoss(config.seg_num_classes, config.ignore_label) + opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) + model = Model(net, loss, opt) + model.train(epoch_size, train_dataset, callback) + print(time.time() - start_time) + print("expect loss: ", callback.loss / 3) + print("expect time: ", callback.time) + expect_loss = 0.5 + expect_time = 35 + assert callback.loss.asnumpy() / 3 <= expect_loss + assert callback.time <= expect_time