forked from OSSInnovation/mindspore
modelzoo_unet: fix the continue-training (resume from a pretrained checkpoint) problem.
parent 5c2597850b
commit 17476ddfa3
@@ -19,7 +19,6 @@
 - [How to use](#how-to-use)
 - [Inference](#inference)
 - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
-- [Transfer Learning](#transfer-learning)
 - [Description of Random Situation](#description-of-random-situation)
 - [ModelZoo Homepage](#modelzoo-homepage)

@@ -130,6 +129,8 @@ Parameters for both training and evaluation can be set in config.py
   'weight_decay': 0.0005,                    # weight decay value
   'loss_scale': 1024.0,                      # loss scale
   'FixedLossScaleManager': 1024.0,           # fixed loss scale
+  'resume': False,                           # whether to train from a pretrained model
+  'resume_ckpt': './',                       # pretrained model path
   ```
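
For context, a minimal usage sketch of the two new options (the checkpoint filename below is hypothetical; it follows the `<prefix>-<epoch>_<step>.ckpt` pattern that ModelCheckpoint produces):

```
# Hypothetical example: edit cfg_unet in the config file, then relaunch train.py.
cfg_unet['resume'] = True
cfg_unet['resume_ckpt'] = './ckpt_0/ckpt_unet_medical_adam-1_600.ckpt'
```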
@@ -260,8 +261,42 @@ If you need to use the trained model to perform inference on multiple hardware platforms
     print("============== Cross valid dice coeff is:", dice_score)
     ```

-### Transfer Learning
+### Continue Training on the Pretrained Model

-To be added.
+- running on Ascend
+
+  ```
+  # Define the model
+  net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+  # Continue training if 'resume' is set to True
+  if cfg['resume']:
+      param_dict = load_checkpoint(cfg['resume_ckpt'])
+      load_param_into_net(net, param_dict)
+
+  # Load the dataset
+  train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
+  train_data_size = train_dataset.get_dataset_size()
+
+  optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
+                      loss_scale=cfg['loss_scale'])
+  criterion = CrossEntropyWithLogits()
+  loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
+
+  model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
+
+
+  # Set callbacks
+  ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
+                                 keep_checkpoint_max=cfg['keep_checkpoint_max'])
+  ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
+                               directory='./ckpt_{}/'.format(device_id),
+                               config=ckpt_config)
+
+  print("============== Starting Training ==============")
+  model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
+              dataset_sink_mode=False)
+  print("============== End Training ==============")
+  ```


 # [Description of Random Situation](#contents)
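
For readers unfamiliar with MindSpore checkpoint handling, a self-contained sketch of the resume step used above (the constructor arguments and the checkpoint path are illustrative assumptions, not values from this commit):

```
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet import UNet

# Illustrative values; the real channel/class counts come from cfg_unet.
net = UNet(n_channels=1, n_classes=2)

# load_checkpoint reads a .ckpt file into a name -> Parameter dictionary;
# load_param_into_net then copies those parameters into the network by name.
param_dict = load_checkpoint('./ckpt_0/ckpt_unet_medical_adam-1_600.ckpt')
load_param_into_net(net, param_dict)
```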
@@ -27,4 +27,7 @@ cfg_unet = {
     'weight_decay': 0.0005,
     'loss_scale': 1024.0,
     'FixedLossScaleManager': 1024.0,
+
+    'resume': False,
+    'resume_ckpt': './',
 }
@@ -24,6 +24,7 @@ from mindspore import Model, context
 from mindspore.communication.management import init, get_group_size
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from src.unet import UNet
 from src.data_loader import create_dataset
@@ -54,6 +55,10 @@ def train_net(data_dir,
                                       gradients_mean=False)
     net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])

+    if cfg['resume']:
+        param_dict = load_checkpoint(cfg['resume_ckpt'])
+        load_param_into_net(net, param_dict)
+
     criterion = CrossEntropyWithLogits()
     train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
     train_data_size = train_dataset.get_dataset_size()
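
One possible hardening of this resume block, assuming a MindSpore version in which load_param_into_net returns the names of network parameters it could not find in the checkpoint, would be to warn on a partial restore:

```
if cfg['resume']:
    param_dict = load_checkpoint(cfg['resume_ckpt'])
    # Assumption: load_param_into_net returns the list of network parameters
    # that were not present in the checkpoint; an empty list means every
    # weight was restored.
    param_not_load = load_param_into_net(net, param_dict)
    if param_not_load:
        print("Warning: parameters not restored from checkpoint:", param_not_load)
```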