forked from OSSInnovation/mindspore
modelzoo_unet fix training continue problem.
This commit is contained in:
parent
5c2597850b
commit
17476ddfa3
|
@ -19,7 +19,6 @@
|
|||
- [How to use](#how-to-use)
|
||||
- [Inference](#inference)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Transfer Learning](#transfer-learning)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
|
@ -130,6 +129,8 @@ Parameters for both training and evaluation can be set in config.py
|
|||
'weight_decay': 0.0005, # weight decay value
|
||||
'loss_scale': 1024.0, # loss scale
|
||||
'FixedLossScaleManager': 1024.0, # fix loss scale
|
||||
'resume': False, # whether training with pretrain model
|
||||
'resume_ckpt': './', # pretrain model path
|
||||
```
|
||||
|
||||
|
||||
|
@ -260,8 +261,42 @@ If you need to use the trained model to perform inference on multiple hardware p
|
|||
print("============== Cross valid dice coeff is:", dice_score)
|
||||
```
|
||||
|
||||
### Transfer Learning
|
||||
To be added.
|
||||
### Continue Training on the Pretrained Model
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```
|
||||
# Define model
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
# Continue training if set 'resume' to be True
|
||||
if cfg['resume']:
|
||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# Load dataset
|
||||
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
|
||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
|
||||
loss_scale=cfg['loss_scale'])
|
||||
criterion = CrossEntropyWithLogits()
|
||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
|
||||
|
||||
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
|
||||
# Set callbacks
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=cfg['keep_checkpoint_max'])
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
|
||||
dataset_sink_mode=False)
|
||||
print("============== End Training ==============")
|
||||
```
|
||||
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
|
|
@ -27,4 +27,7 @@ cfg_unet = {
|
|||
'weight_decay': 0.0005,
|
||||
'loss_scale': 1024.0,
|
||||
'FixedLossScaleManager': 1024.0,
|
||||
|
||||
'resume': False,
|
||||
'resume_ckpt': './',
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore import Model, context
|
|||
from mindspore.communication.management import init, get_group_size
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.unet import UNet
|
||||
from src.data_loader import create_dataset
|
||||
|
@ -54,6 +55,10 @@ def train_net(data_dir,
|
|||
gradients_mean=False)
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
|
||||
if cfg['resume']:
|
||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
criterion = CrossEntropyWithLogits()
|
||||
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
|
|
Loading…
Reference in New Issue