diff --git a/model_zoo/official/cv/unet/README.md b/model_zoo/official/cv/unet/README.md
index 03d5a569ae..c50a5088b1 100644
--- a/model_zoo/official/cv/unet/README.md
+++ b/model_zoo/official/cv/unet/README.md
@@ -19,7 +19,6 @@
     - [How to use](#how-to-use)
         - [Inference](#inference)
         - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
-        - [Transfer Learning](#transfer-learning)
 - [Description of Random Situation](#description-of-random-situation)
 - [ModelZoo Homepage](#modelzoo-homepage)
@@ -130,6 +129,8 @@ Parameters for both training and evaluation can be set in config.py
   'weight_decay': 0.0005,                    # weight decay value
   'loss_scale': 1024.0,                      # loss scale
   'FixedLossScaleManager': 1024.0,           # fix loss scale
+  'resume': False,                           # whether to resume training from a pretrained model
+  'resume_ckpt': './',                       # pretrained model path
   ```
@@ -260,8 +261,42 @@ If you need to use the trained model to perform inference on multiple hardware p
       print("============== Cross valid dice coeff is:", dice_score)
   ```

-### Transfer Learning
-To be added.
+### Continue Training on the Pretrained Model
+
+- running on Ascend
+
+  ```
+  # Define model
+  net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+  # Continue training if 'resume' is set to True
+  if cfg['resume']:
+      param_dict = load_checkpoint(cfg['resume_ckpt'])
+      load_param_into_net(net, param_dict)
+
+  # Load dataset
+  train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
+  train_data_size = train_dataset.get_dataset_size()
+
+  # Optimizer, loss and loss scaling
+  optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
+                      loss_scale=cfg['loss_scale'])
+  criterion = CrossEntropyWithLogits()
+  loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
+
+  model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
+
+  # Set callbacks
+  ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
+                                 keep_checkpoint_max=cfg['keep_checkpoint_max'])
+  ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
+                               directory='./ckpt_{}/'.format(device_id),
+                               config=ckpt_config)
+
+  print("============== Starting Training ==============")
+  model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
+              dataset_sink_mode=False)
+  print("============== End Training ==============")
+  ```

 # [Description of Random Situation](#contents)
diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py
index 15bd904c0e..6946eafca3 100644
--- a/model_zoo/official/cv/unet/src/config.py
+++ b/model_zoo/official/cv/unet/src/config.py
@@ -27,4 +27,7 @@ cfg_unet = {
     'weight_decay': 0.0005,
     'loss_scale': 1024.0,
     'FixedLossScaleManager': 1024.0,
+
+    'resume': False,
+    'resume_ckpt': './',
 }
diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py
index 3bce28a9d7..dedc94899e 100644
--- a/model_zoo/official/cv/unet/train.py
+++ b/model_zoo/official/cv/unet/train.py
@@ -24,6 +24,7 @@ from mindspore import Model, context
 from mindspore.communication.management import init, get_group_size
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from src.unet import UNet
 from src.data_loader import create_dataset
@@ -54,6 +55,10 @@ def train_net(data_dir,
                                           gradients_mean=False)
     net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+    if cfg['resume']:
+        param_dict = load_checkpoint(cfg['resume_ckpt'])
+        load_param_into_net(net, param_dict)
+
     criterion = CrossEntropyWithLogits()
     train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
     train_data_size = train_dataset.get_dataset_size()
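Below is a minimal sketch of how the new `resume`/`resume_ckpt` options are consumed by `train.py`. It is not part of the patch: the checkpoint filename `./unet.ckpt` and the literal config values are illustrative stand-ins for what `cfg_unet` in `src/config.py` provides, and it assumes MindSpore is installed.

```python
# Minimal sketch of the resume branch added in this patch.
# './unet.ckpt' is a hypothetical checkpoint saved from a previous
# run of the same UNet architecture.
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.unet import UNet

cfg = {
    'num_channels': 1,             # mirrors cfg_unet in src/config.py
    'num_classes': 2,
    'resume': True,                # enable loading of a pretrained model
    'resume_ckpt': './unet.ckpt',  # hypothetical checkpoint path
}

net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
if cfg['resume']:
    # load_checkpoint reads the .ckpt file into a dict of parameter
    # name -> Parameter; load_param_into_net copies matching entries
    # into the freshly constructed network before training starts.
    param_dict = load_checkpoint(cfg['resume_ckpt'])
    load_param_into_net(net, param_dict)
```

Because the loaded parameters only need to match the network by name and shape, the same branch can also serve fine-tuning: pointing `resume_ckpt` at any checkpoint saved from this UNet continues training from those weights.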