modelzoo_unet: fix the continue-training (resume from checkpoint) problem.

This commit is contained in:
zhanghuiyao 2020-09-19 10:15:47 +08:00
parent 5c2597850b
commit 17476ddfa3
3 changed files with 46 additions and 3 deletions

View File

@ -19,7 +19,6 @@
- [How to use](#how-to-use) - [How to use](#how-to-use)
- [Inference](#inference) - [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model) - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation) - [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)
@ -130,6 +129,8 @@ Parameters for both training and evaluation can be set in config.py
'weight_decay': 0.0005, # weight decay value 'weight_decay': 0.0005, # weight decay value
'loss_scale': 1024.0, # loss scale 'loss_scale': 1024.0, # loss scale
'FixedLossScaleManager': 1024.0, # fix loss scale 'FixedLossScaleManager': 1024.0, # fix loss scale
'resume': False, # whether to continue training from a pretrained model
'resume_ckpt': './', # path to the pretrained model checkpoint
``` ```
@ -260,8 +261,42 @@ If you need to use the trained model to perform inference on multiple hardware p
print("============== Cross valid dice coeff is:", dice_score) print("============== Cross valid dice coeff is:", dice_score)
``` ```
### Transfer Learning ### Continue Training on the Pretrained Model
To be added.
- running on Ascend
```
# Define model
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
# Continue training from the checkpoint if 'resume' is set to True
if cfg['resume']:
param_dict = load_checkpoint(cfg['resume_ckpt'])
load_param_into_net(net, param_dict)
# Load dataset
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
train_data_size = train_dataset.get_dataset_size()
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
loss_scale=cfg['loss_scale'])
criterion = CrossEntropyWithLogits()
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
# Set callbacks
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
directory='./ckpt_{}/'.format(device_id),
config=ckpt_config)
print("============== Starting Training ==============")
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
dataset_sink_mode=False)
print("============== End Training ==============")
```
# [Description of Random Situation](#contents) # [Description of Random Situation](#contents)

View File

@ -27,4 +27,7 @@ cfg_unet = {
'weight_decay': 0.0005, 'weight_decay': 0.0005,
'loss_scale': 1024.0, 'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0, 'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
} }

View File

@ -24,6 +24,7 @@ from mindspore import Model, context
from mindspore.communication.management import init, get_group_size from mindspore.communication.management import init, get_group_size
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet import UNet from src.unet import UNet
from src.data_loader import create_dataset from src.data_loader import create_dataset
@ -54,6 +55,10 @@ def train_net(data_dir,
gradients_mean=False) gradients_mean=False)
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
if cfg['resume']:
param_dict = load_checkpoint(cfg['resume_ckpt'])
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits() criterion = CrossEntropyWithLogits()
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute) train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
train_data_size = train_dataset.get_dataset_size() train_data_size = train_dataset.get_dataset_size()