forked from mindspore-Ecosystem/mindspore
!21904 【模型开发】r1.3分支resnet50增加多卡训练、ckpt保存策略、图编译相关动作
Merge pull request !21904 from Atlas_hrp/r1.3
This commit is contained in:
commit
97aaeede7f
|
@ -265,6 +265,10 @@ Parameters for both training and evaluation can be set in config file.
|
|||
"lr_init": 0.01, # initial learning rate
|
||||
"lr_end": 0.00001, # final learning rate
|
||||
"lr_max": 0.1, # maximum learning rate
|
||||
"save_graphs": False, # save graph results
|
||||
"save_graphs_path": "./graphs", # save graph results path
|
||||
"has_trained_epoch":0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus has_trained_epoch
|
||||
"has_trained_step":0, # step size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to step_size minus has_trained_step
|
||||
```
|
||||
|
||||
- Config for ResNet18 and ResNet50, ImageNet2012 dataset
|
||||
|
@ -287,6 +291,11 @@ Parameters for both training and evaluation can be set in config file.
|
|||
"lr_init": 0, # initial learning rate
|
||||
"lr_max": 0.8, # maximum learning rate
|
||||
"lr_end": 0.0, # minimum learning rate
|
||||
"save_graphs": False, # save graph results
|
||||
"save_graphs_path": "./graphs", # save graph results path
|
||||
"has_trained_epoch":0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus has_trained_epoch
|
||||
"has_trained_step":0, # step size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to step_size minus has_trained_step
|
||||
|
||||
```
|
||||
|
||||
- Config for ResNet34, ImageNet2012 dataset
|
||||
|
@ -309,6 +318,10 @@ Parameters for both training and evaluation can be set in config file.
|
|||
"lr_init": 0, # initial learning rate
|
||||
"lr_max": 1.0, # maximum learning rate
|
||||
"lr_end": 0.0, # minimum learning rate
|
||||
"save_graphs": False, # save graph results
|
||||
"save_graphs_path": "./graphs", # save graph results path
|
||||
"has_trained_epoch":0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus has_trained_epoch
|
||||
"has_trained_step":0, # step size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to step_size minus has_trained_step
|
||||
```
|
||||
|
||||
- Config for ResNet101, ImageNet2012 dataset
|
||||
|
@ -329,6 +342,10 @@ Parameters for both training and evaluation can be set in config file.
|
|||
"use_label_smooth": True, # label_smooth
|
||||
"label_smooth_factor": 0.1, # label_smooth_factor
|
||||
"lr": 0.1 # base learning rate
|
||||
"save_graphs": False, # save graph results
|
||||
"save_graphs_path": "./graphs", # save graph results path
|
||||
"has_trained_epoch":0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus has_trained_epoch
|
||||
"has_trained_step":0, # step size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to step_size minus has_trained_step
|
||||
```
|
||||
|
||||
- Config for SE-ResNet50, ImageNet2012 dataset
|
||||
|
@ -352,6 +369,10 @@ Parameters for both training and evaluation can be set in config file.
|
|||
"lr_init": 0.0, # initial learning rate
|
||||
"lr_max": 0.3, # maximum learning rate
|
||||
"lr_end": 0.0001, # end learning rate
|
||||
"save_graphs": False, # save graph results
|
||||
"save_graphs_path": "./graphs", # save graph results path
|
||||
"has_trained_epoch":0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus has_trained_epoch
|
||||
"has_trained_step":0, # step size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to step_size minus has_trained_step
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
@ -440,6 +461,20 @@ By default, a standalone cache server would be started to cache all eval images
|
|||
|
||||
Users can choose to shutdown the cache server after training or leave it alone for future usage.
|
||||
|
||||
## [Resume Process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
```text
|
||||
# distributed training
|
||||
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [CONFIG_PATH] [PRETRAINED_CKPT_PATH]
|
||||
|
||||
# standalone training
|
||||
用法:bash run_standalone_train.sh [DATASET_PATH] [CONFIG_PATH] [PRETRAINED_CKPT_PATH]
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
- Training ResNet18 with CIFAR-10 dataset
|
||||
|
|
|
@ -244,8 +244,12 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
|
|||
"warmup_epochs":5, # 热身周期数
|
||||
"lr_decay_mode":"poly” # 衰减模式可为步骤、策略和默认
|
||||
"lr_init":0.01, # 初始学习率
|
||||
"lr_end":0.0001, # 最终学习率
|
||||
"lr_end":0.0001, # 最终学习率
|
||||
"lr_max":0.1, # 最大学习率
|
||||
"save_graphs":False, # 是否保存图编译结果
|
||||
"save_graphs_path":"./graphs", # 图编译结果保存路径
|
||||
"has_trained_epoch":0, # 加载已经训练好的模型的epoch大小;实际训练周期大小等于epoch_size减去has_trained_epoch
|
||||
"has_trained_step":0, # 加载已经训练好的模型的step大小;实际训练周期大小等于step_size减去has_trained_step
|
||||
```
|
||||
|
||||
- 配置ResNet18、ResNet50和ImageNet2012数据集。
|
||||
|
@ -268,6 +272,10 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
|
|||
"lr_init":0, # 初始学习率
|
||||
"lr_max":0.8, # 最大学习率
|
||||
"lr_end":0.0, # 最小学习率
|
||||
"save_graphs":False, # 是否保存图编译结果
|
||||
"save_graphs_path":"./graphs", # 图编译结果保存路径
|
||||
"has_trained_epoch":0, # 加载已经训练好的模型的epoch大小;实际训练周期大小等于epoch_size减去has_trained_epoch
|
||||
"has_trained_step":0, # 加载已经训练好的模型的step大小;实际训练周期大小等于step_size减去has_trained_step
|
||||
```
|
||||
|
||||
- 配置ResNet34和ImageNet2012数据集。
|
||||
|
@ -290,6 +298,10 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
|
|||
"lr_init":0, # 初始学习率
|
||||
"lr_max":1.0, # 最大学习率
|
||||
"lr_end":0.0, # 最小学习率
|
||||
"save_graphs":False, # 是否保存图编译结果
|
||||
"save_graphs_path":"./graphs", # 图编译结果保存路径
|
||||
"has_trained_epoch":0, # 加载已经训练好的模型的epoch大小;实际训练周期大小等于epoch_size减去has_trained_epoch
|
||||
"has_trained_step":0, # 加载已经训练好的模型的step大小;实际训练周期大小等于step_size减去has_trained_step
|
||||
```
|
||||
|
||||
- 配置ResNet101和ImageNet2012数据集。
|
||||
|
@ -310,6 +322,10 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
|
|||
"use_label_smooth":True, # 标签平滑
|
||||
"label_smooth_factor":0.1, # 标签平滑因子
|
||||
"lr":0.1 # 基础学习率
|
||||
"save_graphs":False, # 是否保存图编译结果
|
||||
"save_graphs_path":"./graphs", # 图编译结果保存路径
|
||||
"has_trained_epoch":0, # 加载已经训练好的模型的epoch大小;实际训练周期大小等于epoch_size减去has_trained_epoch
|
||||
"has_trained_step":0, # 加载已经训练好的模型的step大小;实际训练周期大小等于step_size减去has_trained_step
|
||||
```
|
||||
|
||||
- 配置SE-ResNet50和ImageNet2012数据集。
|
||||
|
@ -333,6 +349,10 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
|
|||
"lr_init":0.0, # 初始学习率
|
||||
"lr_max":0.3, # 最大学习率
|
||||
"lr_end":0.0001, # 最终学习率
|
||||
"save_graphs":False, # 是否保存图编译结果
|
||||
"save_graphs_path":"./graphs", # 图编译结果保存路径
|
||||
"has_trained_epoch":0, # 加载已经训练好的模型的epoch大小;实际训练周期大小等于epoch_size减去has_trained_epoch
|
||||
"has_trained_step":0, # 加载已经训练好的模型的step大小;实际训练周期大小等于step_size减去has_trained_step
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
@ -410,6 +430,20 @@ bash run_standalone_train_gpu.sh [CONFIG_PATH] [RUN_EVAL](optional) [EVAL_DATASE
|
|||
|
||||
在训练结束后,可以选择关闭缓存服务器或不关闭它以继续为未来的推理提供缓存服务。
|
||||
|
||||
## 续训过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```text
|
||||
# 分布式训练
|
||||
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [CONFIG_PATH] [PRETRAINED_CKPT_PATH]
|
||||
|
||||
# 单机训练
|
||||
用法:bash run_standalone_train.sh [DATASET_PATH] [CONFIG_PATH] [PRETRAINED_CKPT_PATH]
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
- 使用CIFAR-10数据集训练ResNet18
|
||||
|
|
|
@ -67,6 +67,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet101_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -82,3 +88,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
|
@ -63,6 +63,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet18_cifar10"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -78,3 +84,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -65,6 +65,11 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet18_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -80,3 +85,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -65,6 +65,11 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet34_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -80,3 +85,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -66,6 +66,11 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_cifar10"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -81,3 +86,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -68,6 +68,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -83,3 +89,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -69,6 +69,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -84,3 +90,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -69,6 +69,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -84,3 +90,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -68,6 +68,11 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -83,3 +88,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -38,6 +38,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -53,3 +59,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -69,6 +69,12 @@ file_format: "AIR"
|
|||
ckpt_file: ""
|
||||
network_dataset: "se-resnet50_imagenet2012"
|
||||
|
||||
# Retrain options
|
||||
save_graphs: False
|
||||
save_graphs_path: "./graphs"
|
||||
has_trained_epoch: 0
|
||||
has_trained_step: 0
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
|
@ -84,3 +90,5 @@ batch_size: "Batch size for training and evaluation"
|
|||
epoch_size: "Total training epochs."
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
||||
save_graphs: "Whether save graphs during training, default: False."
|
||||
save_graphs_path: "Path to save graphs."
|
||||
|
|
|
@ -13,8 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.nn.optim import Momentum, thor, LARS
|
||||
|
@ -32,6 +35,7 @@ from mindspore.parallel import set_algo_parameters
|
|||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
import mindspore.log as logger
|
||||
|
||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
from src.eval_callback import EvalCallBack
|
||||
|
@ -43,6 +47,38 @@ from src.resnet import conv_variance_scaling_initializer
|
|||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
class LossCallBack(LossMonitor):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss in NAN or INF terminating training.
|
||||
"""
|
||||
|
||||
def __init__(self, has_trained_epoch=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
self.has_trained_epoch = has_trained_epoch
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(loss, (tuple, list)):
|
||||
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
|
||||
loss = loss[0]
|
||||
|
||||
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
|
||||
loss = np.mean(loss.asnumpy())
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
|
||||
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch))
|
||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + int(self.has_trained_epoch),
|
||||
cur_step_in_epoch, loss), flush=True)
|
||||
|
||||
|
||||
if config.net_name in ("resnet18", "resnet34", "resnet50"):
|
||||
if config.net_name == "resnet18":
|
||||
from src.resnet import resnet18 as resnet
|
||||
|
@ -64,6 +100,7 @@ else:
|
|||
from src.resnet import se_resnet50 as resnet
|
||||
from src.dataset import create_dataset4 as create_dataset
|
||||
|
||||
|
||||
def acc_group_params_generator(net, weight_decay):
|
||||
acc_weight = 'end_point.weight'
|
||||
acc_multiplier = 'end_point.multiplier'
|
||||
|
@ -94,6 +131,7 @@ def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
|
|||
del origin_dict[key]
|
||||
break
|
||||
|
||||
|
||||
def apply_eval(eval_param):
|
||||
eval_model = eval_param["model"]
|
||||
eval_ds = eval_param["dataset"]
|
||||
|
@ -101,11 +139,13 @@ def apply_eval(eval_param):
|
|||
res = eval_model.eval(eval_ds)
|
||||
return res[metrics_name]
|
||||
|
||||
|
||||
def set_graph_kernel_context(run_platform, net_name):
|
||||
if run_platform == "GPU" and net_name == "resnet101":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
context.set_context(graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")
|
||||
|
||||
|
||||
def set_parameter():
|
||||
"""set_parameter"""
|
||||
target = config.device_target
|
||||
|
@ -113,8 +153,14 @@ def set_parameter():
|
|||
config.run_distribute = False
|
||||
|
||||
# init context
|
||||
rank_save_graphs_path = os.path.join(config.save_graphs_path, "soma")
|
||||
|
||||
# Whether open graph saving
|
||||
config.save_graphs = not config.pre_trained
|
||||
|
||||
if config.mode_name == 'GRAPH':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=config.save_graphs,
|
||||
save_graphs_path=rank_save_graphs_path)
|
||||
set_graph_kernel_context(target, config.net_name)
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
|
||||
|
@ -142,14 +188,44 @@ def set_parameter():
|
|||
if config.net_name == "resnet50":
|
||||
context.set_auto_parallel_context(all_reduce_fusion_config=config.all_reduce_fusion_config)
|
||||
|
||||
def init_weight(net):
|
||||
|
||||
def load_pre_trained_checkpoint():
|
||||
"""
|
||||
Load checkpoint according to pre_trained path.
|
||||
"""
|
||||
param_dict = None
|
||||
ckpt_save_dir = set_save_ckpt_dir()
|
||||
if config.pre_trained:
|
||||
if os.path.isdir(config.pre_trained):
|
||||
ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
|
||||
ckpt_files = glob.glob(ckpt_pattern)
|
||||
if not ckpt_files:
|
||||
logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
|
||||
f"pre_trained is unsupported.")
|
||||
else:
|
||||
ckpt_files.sort(key=os.path.getmtime, reverse=True)
|
||||
time_stamp = datetime.datetime.now()
|
||||
print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
|
||||
f" pre trained ckpt model {ckpt_files[0]} loading",
|
||||
flush=True)
|
||||
param_dict = load_checkpoint(ckpt_files[0])
|
||||
elif os.path.isfile(config.pre_trained):
|
||||
param_dict = load_checkpoint(config.pre_trained)
|
||||
else:
|
||||
print(f"Invalid pre_trained {config.pre_trained} parameter.")
|
||||
return param_dict
|
||||
|
||||
|
||||
def init_weight(net, param_dict):
|
||||
"""init_weight"""
|
||||
if config.pre_trained:
|
||||
param_dict = load_checkpoint(config.pre_trained)
|
||||
if config.filter_weight:
|
||||
filter_list = [x.name for x in net.end_point.get_parameters()]
|
||||
filter_checkpoint_parameter_by_list(param_dict, filter_list)
|
||||
load_param_into_net(net, param_dict)
|
||||
if param_dict:
|
||||
config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
|
||||
config.has_trained_step = int(param_dict["step_num"].data.asnumpy())
|
||||
if config.filter_weight:
|
||||
filter_list = [x.name for x in net.end_point.get_parameters()]
|
||||
filter_checkpoint_parameter_by_list(param_dict, filter_list)
|
||||
load_param_into_net(net, param_dict)
|
||||
else:
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
|
@ -174,6 +250,7 @@ def init_weight(net):
|
|||
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
|
||||
cell.weight.set_data(weight)
|
||||
|
||||
|
||||
def init_lr(step_size):
|
||||
"""init lr"""
|
||||
if config.optimizer == "Thor":
|
||||
|
@ -189,6 +266,7 @@ def init_lr(step_size):
|
|||
config.pretrain_epoch_size * step_size)
|
||||
return lr
|
||||
|
||||
|
||||
def init_loss_scale():
|
||||
if config.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
|
@ -249,6 +327,7 @@ def train_net():
|
|||
"""train net"""
|
||||
target = config.device_target
|
||||
set_parameter()
|
||||
ckpt_param_dict = load_pre_trained_checkpoint()
|
||||
dataset = create_dataset(dataset_path=config.data_path, do_train=True, repeat_num=1,
|
||||
batch_size=config.batch_size, target=target,
|
||||
distribute=config.run_distribute)
|
||||
|
@ -256,7 +335,7 @@ def train_net():
|
|||
net = resnet(class_num=config.class_num)
|
||||
if config.parameter_server:
|
||||
net.set_param_ps()
|
||||
init_weight(net=net)
|
||||
init_weight(net=net, param_dict=ckpt_param_dict)
|
||||
lr = Tensor(init_lr(step_size=step_size))
|
||||
# define opt
|
||||
group_params = init_group_prams(net)
|
||||
|
@ -293,12 +372,14 @@ def train_net():
|
|||
|
||||
# define callbacks
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
loss_cb = LossCallBack(config.has_trained_epoch)
|
||||
cb = [time_cb, loss_cb]
|
||||
ckpt_save_dir = set_save_ckpt_dir()
|
||||
if config.save_checkpoint:
|
||||
ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
keep_checkpoint_max=config.keep_checkpoint_max,
|
||||
append_info=ckpt_append_info)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
run_eval(target, model, ckpt_save_dir, cb)
|
||||
|
@ -306,11 +387,13 @@ def train_net():
|
|||
if config.net_name == "se-resnet50":
|
||||
config.epoch_size = config.train_epoch_size
|
||||
dataset_sink_mode = (not config.parameter_server) and target != "CPU"
|
||||
config.pretrain_epoch_size = config.has_trained_epoch
|
||||
model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
|
||||
sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)
|
||||
|
||||
if config.run_eval and config.enable_cache:
|
||||
print("Remember to shut down the cache server via \"cache_admin --stop\"")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_net()
|
||||
|
|
Loading…
Reference in New Issue