!22504 add lessBN optimization in mobilenetv2

Merge pull request !22504 from guoqi/mobilenet-lessbn
i-robot 2021-09-04 01:34:42 +00:00 committed by Gitee
commit b4f4dcb514
17 changed files with 537 additions and 273 deletions

View File

@ -356,7 +356,10 @@ bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int
}
const auto &use_pattern = std::get<1>(patternTuple);
int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
return IsPrimitiveCNode(cnode, use_pattern[IntToSize(use_index)]);
if (!IsPrimitiveCNode(cnode, use_pattern[use_index]) && use_pattern[use_index] != prim::kPrimTupleGetItem) {
return false;
}
return true;
}
bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
@ -384,7 +387,8 @@ bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
}
void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) &&
!IsValueNode<FuncGraph>(cnode->input(0))) {
return;
}
if (match_pattern.empty()) {

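The C++ change above relaxes the pattern matching of the LessBN pass: a TupleGetItem entry in a pattern no longer has to match exactly, and sub-graph call nodes are also considered for removal. From the Python side the pass is enabled through the acc level handed to Model, which is how train.py later in this diff turns it on (acc_level=config.acc_mode, set to "O1" in default_config_acc.yaml). A minimal, illustrative sketch with a toy network; the network, loss and optimizer here are assumptions, not part of this PR:

```python
import mindspore.nn as nn
from mindspore import Model

# Toy classifier used only to show where acc_level plugs in.
net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
                         nn.Flatten(), nn.Dense(8 * 32 * 32, 10)])
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# acc_level="O1" asks the acc module to apply the LessBN rewrite shown above;
# "O0" leaves the network untouched.
model = Model(net, loss_fn=loss, optimizer=opt, acc_level="O1")
```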
View File

@ -76,7 +76,7 @@ class AutoAcc:
optimizer_process.origin_params = \
self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
if self._gc_flag:
optimizer_process.add_grad_centralization()
optimizer_process.add_grad_centralization(network)
optimizer = optimizer_process.generate_new_optimizer()
if self._acc_config["grad_freeze"]:

View File

@ -1,166 +1,187 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
import copy
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter
__all__ = ["OptimizerProcess", "ParameterProcess"]
class OptimizerProcess:
"""
Process optimizer for ACC.
Args:
opt (Cell): Optimizer used.
"""
def __init__(self, opt):
if isinstance(opt, LARS):
self.is_lars = True
self.opt_class = type(opt.opt)
self.opt_init_args = opt.opt.init_args
self.lars_init_args = opt.init_args
else:
self.is_lars = False
self.opt_class = type(opt)
self.opt_init_args = opt.init_args
self.origin_params = opt.init_params["params"]
def add_grad_centralization(self):
"""Add gradient centralization."""
parameters = self.origin_params
if parameters is not None and not isinstance(parameters, list):
parameters = list(parameters)
if not parameters:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(parameters[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(parameters[0], Parameter):
logger.warning("Only group parameters support gradient centralization.")
return
group_params = []
for group_param in parameters:
if 'order_params' in group_param.keys():
group_params.append(group_param)
continue
params_gc_value = []
params_value = []
for param in group_param['params']:
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
params_gc_value.append(param)
else:
params_value.append(param)
if params_gc_value:
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_gc_value
new_group_param['grad_centralization'] = True
group_params.append(new_group_param)
if params_value:
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
group_params.append(new_group_param)
self.origin_params = group_params
def generate_new_optimizer(self):
"""Generate new optimizer."""
if not self.is_lars:
opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
else:
opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args), **self.lars_init_args)
return opt
class ParameterProcess:
"""
Process parameter for ACC.
"""
def __init__(self):
self._parameter_indices = 1
def assign_parameter_group(self, parameters, split_point=None):
"""Assign parameter group."""
if not isinstance(parameters, (list, tuple)) or not parameters:
return parameters
parameter_len = len(parameters)
if split_point:
split_parameter_index = split_point
else:
split_parameter_index = [parameter_len // 2]
for i in range(parameter_len):
if i in split_parameter_index:
self._parameter_indices += 1
parameters[i].comm_fusion = self._parameter_indices
return parameters
def generate_group_params(self, parameters, origin_params):
"""Generate group parameters."""
origin_params_copy = origin_params
if origin_params_copy is not None:
if not isinstance(origin_params_copy, list):
origin_params_copy = list(origin_params_copy)
if not origin_params_copy:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(origin_params_copy[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(origin_params_copy[0], Parameter):
group_params = [{"params": parameters}]
return group_params
group_params = []
params_name = [param.name for param in parameters]
new_params_count = copy.deepcopy(params_name)
new_params_clone = {}
max_key_number = 0
for group_param in origin_params_copy:
if 'order_params' in group_param.keys():
new_group_param = copy.deepcopy(group_param)
new_group_param['order_params'] = parameters
group_params.append(new_group_param)
continue
params_value = []
for param in group_param['params']:
if param.name in params_name:
index = params_name.index(param.name)
params_value.append(parameters[index])
new_params_count.remove(param.name)
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
group_params.append(new_group_param)
if len(group_param.keys()) > max_key_number:
max_key_number = len(group_param.keys())
new_params_clone = copy.deepcopy(group_param)
if new_params_count:
params_value = []
for param in new_params_count:
index = params_name.index(param)
params_value.append(parameters[index])
if new_params_clone:
new_params_clone['params'] = params_value
group_params.append(new_params_clone)
else:
group_params.append({"params": params_value})
return group_params
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
import copy
import mindspore.nn as nn
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter
from .less_batch_normalization import CommonHeadLastFN
__all__ = ["OptimizerProcess", "ParameterProcess"]
class OptimizerProcess:
"""
Process optimizer for ACC.
Args:
opt (Cell): Optimizer used.
"""
def __init__(self, opt):
if isinstance(opt, LARS):
self.is_lars = True
self.opt_class = type(opt.opt)
self.opt_init_args = opt.opt.init_args
self.lars_init_args = opt.init_args
else:
self.is_lars = False
self.opt_class = type(opt)
self.opt_init_args = opt.init_args
self.origin_params = opt.init_params["params"]
def build_params_dict(self, network):
"""Build the params dict of the network"""
cells = network.cells_and_names()
params_dict = {}
for _, cell in cells:
for par in cell.get_parameters(expand=False):
params_dict[id(par)] = cell
return params_dict
def build_gc_params_group(self, params_dict, parameters):
"""Build the params group that needs gc"""
group_params = []
for group_param in parameters:
if 'order_params' in group_param.keys():
group_params.append(group_param)
continue
params_gc_value = []
params_value = []
for param in group_param['params']:
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
param_cell = params_dict[id(param)]
if (isinstance(param_cell, nn.Conv2d) and param_cell.group > 1) or \
isinstance(param_cell, CommonHeadLastFN):
params_value.append(param)
else:
params_gc_value.append(param)
else:
params_value.append(param)
if params_gc_value:
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_gc_value
new_group_param['grad_centralization'] = True
group_params.append(new_group_param)
if params_value:
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
group_params.append(new_group_param)
return group_params
def add_grad_centralization(self, network):
"""Add gradient centralization."""
params_dict = self.build_params_dict(network)
parameters = self.origin_params
if parameters is not None and not isinstance(parameters, list):
parameters = list(parameters)
if not parameters:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(parameters[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(parameters[0], Parameter):
logger.warning("Only group parameters support gradient centralization.")
return
self.origin_params = self.build_gc_params_group(params_dict, parameters)
def generate_new_optimizer(self):
"""Generate new optimizer."""
if not self.is_lars:
opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
else:
opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args), **self.lars_init_args)
return opt
class ParameterProcess:
"""
Process parameter for ACC.
"""
def __init__(self):
self._parameter_indices = 1
def assign_parameter_group(self, parameters, split_point=None):
"""Assign parameter group."""
if not isinstance(parameters, (list, tuple)) or not parameters:
return parameters
parameter_len = len(parameters)
if split_point:
split_parameter_index = split_point
else:
split_parameter_index = [parameter_len // 2]
for i in range(parameter_len):
if i in split_parameter_index:
self._parameter_indices += 1
parameters[i].comm_fusion = self._parameter_indices
return parameters
def generate_group_params(self, parameters, origin_params):
"""Generate group parameters."""
origin_params_copy = origin_params
if origin_params_copy is not None:
if not isinstance(origin_params_copy, list):
origin_params_copy = list(origin_params_copy)
if not origin_params_copy:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(origin_params_copy[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(origin_params_copy[0], Parameter):
group_params = [{"params": parameters}]
return group_params
group_params = []
params_name = [param.name for param in parameters]
new_params_count = copy.deepcopy(params_name)
new_params_clone = {}
max_key_number = 0
for group_param in origin_params_copy:
if 'order_params' in group_param.keys():
new_group_param = copy.deepcopy(group_param)
new_group_param['order_params'] = parameters
group_params.append(new_group_param)
continue
params_value = []
for param in group_param['params']:
if param.name in params_name:
index = params_name.index(param.name)
params_value.append(parameters[index])
new_params_count.remove(param.name)
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
group_params.append(new_group_param)
if len(group_param.keys()) > max_key_number:
max_key_number = len(group_param.keys())
new_params_clone = copy.deepcopy(group_param)
if new_params_count:
params_value = []
for param in new_params_count:
index = params_name.index(param)
params_value.append(parameters[index])
if new_params_clone:
new_params_clone['params'] = params_value
group_params.append(new_params_clone)
else:
group_params.append({"params": params_value})
return group_params
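A short usage sketch of the reworked OptimizerProcess (not part of this PR): the import path and the tiny network are assumptions made for illustration. Passing the network lets add_grad_centralization map every parameter back to its cell, so grouped/depthwise Conv2d layers and a CommonHeadLastFN head are kept out of the gradient-centralization group.

```python
import mindspore.nn as nn
from mindspore.nn.acc import OptimizerProcess   # assumed import path for the module above

class TinyNet(nn.Cell):
    """Illustrative network with a conv, a batchnorm and a dense head."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(8 * 32 * 32, 10)

    def construct(self, x):
        return self.fc(self.flatten(self.bn(self.conv(x))))

net = TinyNet()

# Group the parameters the same way build_params_groups() does in train.py below:
# weights get weight decay, beta/gamma/bias do not.
decayed, no_decayed = [], []
for p in net.trainable_params():
    if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
        decayed.append(p)
    else:
        no_decayed.append(p)
group_params = [{'params': decayed, 'weight_decay': 4e-5},
                {'params': no_decayed},
                {'order_params': net.trainable_params()}]

opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)
opt_process = OptimizerProcess(opt)
opt_process.add_grad_centralization(net)   # the network argument maps parameters to cells
opt = opt_process.generate_new_optimizer()
```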

View File

@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
__all__ = ["LessBN"]
__all__ = ["CommonHeadLastFN", "LessBN"]
class CommonHeadLastFN(Cell):

View File

@ -268,34 +268,34 @@ You can start training using python or shell scripts. The usage of shell scripts
```shell
# training example
python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH]
Ascend: python train.py --platform Ascend --config_path [CONFIG_PATH] --dataset_path [TRAIN_DATASET_PATH]
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH]
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH]
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH]
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH]
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH]
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH]
# fine tune whole network example
python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
Ascend: python train.py --platform Ascend --config_path [CONFIG_PATH] --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none True
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none True
# fine tune fully connected layers example
python:
Ascend: python --platform Ascend train.py --dataset_path [TRAIN_DATASET_PATH]--pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
GPU: python --platform GPU train.py --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
CPU: python --platform CPU train.py --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
Ascend: python train.py --platform Ascend --config_path default_config.yaml --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
```

View File

@ -211,6 +211,7 @@ The overall architecture of MobileNetV2 is as follows
│ ├──local_adapter.py # get the local device id
│ └──moxing_adapter.py # data preparation on the cloud
├── default_config.yaml # training configuration parameters (Ascend)
├── default_config_acc.yaml # training configuration parameters (Ascend acc mode)
├── default_config_cpu.yaml # training configuration parameters (CPU)
├── default_config_gpu.yaml # training configuration parameters (GPU)
├── train.py # training script
@ -226,7 +227,7 @@ The overall architecture of MobileNetV2 is as follows
You can start training using python or shell scripts. The usage of the shell scripts is as follows:
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- Ascend: sh run_train.sh Ascend [CONFIG_PATH] [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- GPU: bash run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- CPU: bash run_train.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
@ -269,34 +270,34 @@ The overall architecture of MobileNetV2 is as follows
```shell
# training example
python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH]
Ascend: python train.py --platform Ascend --config_path [CONFIG_PATH] --dataset_path [TRAIN_DATASET_PATH]
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH]
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH]
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH]
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH]
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH]
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH]
# fine tune whole network example
python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
Ascend: python train.py --platform Ascend --config_path [CONFIG_PATH] --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none True
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none True
# fine tune fully connected layers example
python:
Ascend: python --platform Ascend train.py --dataset_path [TRAIN_DATASET_PATH]--pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
GPU: python --platform GPU train.py --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
CPU: python --platform CPU train.py --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
Ascend: python train.py --platform Ascend --config_path default_config.yaml --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer backbone
shell:
Ascend: bash run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
Ascend: bash run_train.sh Ascend default_config.yaml 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
GPU: bash run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
CPU: bash run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] backbone
```

View File

@ -17,6 +17,8 @@ need_modelarts_dataset_unzip: True
num_classes: 1000
image_height: 224
image_width: 224
num_workers: 32
acc_mode: "O0"
batch_size: 256
epoch_size: 200
warmup_epochs: 4

View File

@ -0,0 +1,97 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
device_target: Ascend
enable_profiling: False
# ==============================================================================
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True
num_classes: 1000
image_height: 224
image_width: 224
num_workers: 32
acc_mode: "O1"
batch_size: 256
epoch_size: 200
warmup_epochs: 4
lr_init: 0.00
lr_end: 0.00
lr_max: 0.4
momentum: 0.9
weight_decay: 0.00004
label_smooth: 0.1
loss_scale: 1024
save_checkpoint: False
save_checkpoint_epochs: 1
keep_checkpoint_max: 200
save_checkpoint_path: "./"
platform: 'Ascend'
device_id: 0
rank_id: 0
rank_size: 1
run_distribute: False
activation: "Softmax"
# Image classification train. train_parse_args(): return train_args
dataset_path: "/cache/data"
pretrain_ckpt: "./mobilenetv2-200_625.ckpt"
freeze_layer: ""
filter_head: False
enable_cache: False
cache_session_id: ""
is_training: True
# mobilenetv2 eval
is_training_eval: False
run_distribute_eval: False
# mobilenetv2 export
device_id_export: 0
batch_size_export: 1
ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt"
file_name: "mobilenetv2"
file_format: "MINDIR"
is_training_export: False
run_distribute_export: False
# postprocess.py / mobilenetv2 acc calculation
batch_size_postprocess: 1
result_path: ''
label_path: ''
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'
pretrain_ckpt: 'Pretrained checkpoint path for fine tune or incremental learning'
platform: 'Target device type'
freeze_layer: 'freeze the weights of network from start to which layers'
filter_head: 'Filter head weight parameters when load checkpoint, default is False.'
enable_cache: 'Caching the dataset in memory to speedup dataset processing, default is False.'
cache_session_id: 'The session id for cache service.'
file_name: "output file name."
result_path: "result files path."
label_path: "label path."
enable_profiling: 'Whether enable profiling while training, default: False'
run_distribute: 'Run distribute, default is false.'
device_id: 'Device id, default is 0.'
rank_id: 'Rank id, default is 0.'
file_format: 'file format'
---
platform: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]
freeze_layer: ["", "none", "backbone"]

View File

@ -17,6 +17,8 @@ need_modelarts_dataset_unzip: True
num_classes: 26
image_height: 224
image_width: 224
num_workers: 8
acc_mode: "O0"
batch_size: 150
epoch_size: 15
warmup_epochs: 0

View File

@ -17,6 +17,8 @@ need_modelarts_dataset_unzip: True
num_classes: 1000
image_height: 224
image_width: 224
num_workers: 8
acc_mode: "O0"
batch_size: 150
epoch_size: 200
warmup_epochs: 0

View File

@ -16,50 +16,50 @@
run_ascend()
{
if [ $# = 5 ] ; then
if [ $# = 6 ] ; then
PRETRAINED_CKPT=""
FREEZE_LAYER="none"
FILTER_HEAD="False"
elif [ $# = 7 ] ; then
PRETRAINED_CKPT=$6
FREEZE_LAYER=$7
FILTER_HEAD="False"
elif [ $# = 8 ] ; then
PRETRAINED_CKPT=$6
FREEZE_LAYER=$7
FILTER_HEAD=$8
PRETRAINED_CKPT=$7
FREEZE_LAYER=$8
FILTER_HEAD="False"
elif [ $# = 9 ] ; then
PRETRAINED_CKPT=$7
FREEZE_LAYER=$8
FILTER_HEAD=$9
else
echo "Usage:
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH](optional) [FREEZE_LAYER](optional) [FILTER_HEAD](optional)
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]"
Ascend: sh run_train.sh Ascend [CONFIG_FILE] [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH](optional) [FREEZE_LAYER](optional) [FILTER_HEAD](optional)
Ascend: sh run_train.sh Ascend [CONFIG_FILE] [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi;
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
if [ $3 -lt 1 ] || [ $3 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
echo "error: DEVICE_NUM=$3 is not in (1-8)"
exit 1
fi
if [ ! -d $5 ] && [ ! -f $5 ]
if [ ! -d $6 ] && [ ! -f $6 ]
then
echo "error: DATASET_PATH=$5 is not a directory or file"
echo "error: DATASET_PATH=$6 is not a directory or file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASEPATH}/../default_config.yaml"
CONFIG_FILE="${BASEPATH}/../$2"
VISIABLE_DEVICES=$3
VISIABLE_DEVICES=$4
IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
if [ ${#CANDIDATE_DEVICE[@]} -ne $3 ]
then
echo "error: DEVICE_NUM=$2 is not equal to the length of VISIABLE_DEVICES=$3"
echo "error: DEVICE_NUM=$3 is not equal to the length of VISIABLE_DEVICES=$4"
exit 1
fi
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export RANK_TABLE_FILE=$4
export RANK_SIZE=$2
export RANK_TABLE_FILE=$5
export RANK_SIZE=$3
if [ -d "../train" ];
then
rm -rf ../train
@ -87,7 +87,7 @@ run_ascend()
taskset -c $cmdopt python train.py \
--config_path=$CONFIG_FILE \
--platform=$1 \
--dataset_path=$5 \
--dataset_path=$6 \
--pretrain_ckpt=$PRETRAINED_CKPT \
--freeze_layer=$FREEZE_LAYER \
--filter_head=$FILTER_HEAD \

View File

@ -48,30 +48,32 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1, enable_cache=Fa
else:
nfs_dataset_cache = None
num_workers = config.num_workers
if config.platform == "Ascend":
rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0'))
if rank_size == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
cache=nfs_dataset_cache)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
num_shards=rank_size, shard_id=rank_id, cache=nfs_dataset_cache)
elif config.platform == "GPU":
if do_train:
if config.run_distribute:
from mindspore.communication.management import get_rank, get_group_size
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank(),
cache=nfs_dataset_cache)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
cache=nfs_dataset_cache)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
cache=nfs_dataset_cache)
elif config.platform == "CPU":
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, cache=nfs_dataset_cache)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, \
shuffle=True, cache=nfs_dataset_cache)
resize_height = config.image_height
resize_width = config.image_width
@ -96,7 +98,7 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1, enable_cache=Fa
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=num_workers)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
# apply shuffle operations

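The hard-coded num_parallel_workers=8 now comes from the num_workers key added to the YAML configs above. A brief, illustrative call; the dataset path is a placeholder and the override is only to show that the worker count is configurable:

```python
from src.model_utils.config import config
from src.dataset import create_dataset

config.num_workers = 16   # override the YAML value, e.g. on a smaller host
train_ds = create_dataset(dataset_path="/path/to/ImageNet/train", do_train=True,
                          config=config, enable_cache=False)
print(train_ds.get_dataset_size())
```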
View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation metric."""
from mindspore.communication.management import GlobalComm
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
class ClassifyCorrectCell(nn.Cell):
r"""
Cell that returns correct count of the prediction in classification network.
This Cell accepts a network as arguments.
It returns the correct count of the prediction to calculate the metrics.
Args:
network (Cell): The network Cell.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tuple, containing a scalar correct count of the prediction
Examples:
>>> # For a defined network Net without loss function
>>> net = Net()
>>> eval_net = ClassifyCorrectCell(net)
"""
def __init__(self, network):
super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
self._network = network
self.argmax = P.Argmax()
self.equal = P.Equal()
self.cast = P.Cast()
self.reduce_sum = P.ReduceSum()
self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
def construct(self, data, label):
outputs = self._network(data)
y_pred = self.argmax(outputs)
y_pred = self.cast(y_pred, mstype.int32)
y_correct = self.equal(y_pred, label)
y_correct = self.cast(y_correct, mstype.float32)
y_correct = self.reduce_sum(y_correct)
total_correct = self.allreduce(y_correct)
return (total_correct,)
class DistAccuracy(nn.Metric):
r"""
Calculates the accuracy for classification data in distributed mode.
The accuracy class creates two local variables, the correct number and the total number, which are used to compute the
frequency with which predictions match labels. This frequency is ultimately returned as the accuracy: an
idempotent operation that simply divides correct number by total number.
.. math::
\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
{\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
Args:
eval_type (str): Metric to calculate the accuracy over a dataset, for classification (single-label).
Examples:
>>> y_correct = Tensor(np.array([20]))
>>> metric = DistAccuracy(batch_size=3, device_num=8)
>>> metric.clear()
>>> metric.update(y_correct)
>>> accuracy = metric.eval()
"""
def __init__(self, batch_size, device_num):
super(DistAccuracy, self).__init__()
self.clear()
self.batch_size = batch_size
self.device_num = device_num
def clear(self):
"""Clears the internal evaluation result."""
self._correct_num = 0
self._total_num = 0
def update(self, *inputs):
"""
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
Args:
inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
`y_correct` is the correct prediction count gathered from all devices;
it is a scalar Tensor of float type.
Raises:
ValueError: If the number of the input is not 1.
"""
if len(inputs) != 1:
raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
y_correct = self._convert_data(inputs[0])
self._correct_num += y_correct
self._total_num += self.batch_size * self.device_num
def eval(self):
"""
Computes the accuracy.
Returns:
Float, the computed result.
Raises:
RuntimeError: If the sample size is 0.
"""
if self._total_num == 0:
raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num

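train.py below wires these two classes into Model for distributed evaluation; here is a compact sketch of that wiring, wrapped in a helper so it stays self-contained (the helper name is illustrative, not part of the PR):

```python
from mindspore.train.model import Model
from src.metric import DistAccuracy, ClassifyCorrectCell

def build_dist_eval_model(net, loss, opt, batch_size, device_num):
    """Return a Model whose eval() reports accuracy aggregated over all devices."""
    # DistAccuracy divides the all-reduced correct count by batch_size * device_num
    # accumulated over the evaluation steps.
    metrics = {'acc': DistAccuracy(batch_size=batch_size, device_num=device_num)}
    # ClassifyCorrectCell runs the forward pass and all-reduces the per-device
    # correct-prediction count.
    eval_network = ClassifyCorrectCell(net)
    return Model(net, loss_fn=loss, optimizer=opt,
                 metrics=metrics, eval_network=eval_network,
                 amp_level="O2", keep_batchnorm_fp32=False)
```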
View File

@ -19,7 +19,7 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import Add
from mindspore import Tensor
__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']
__all__ = ['MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']
def _make_divisible(v, divisor, min_value=None):
@ -245,9 +245,10 @@ class MobileNetV2Head(nn.Cell):
def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):
super(MobileNetV2Head, self).__init__()
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])
head = ([GlobalAvgPooling()] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2)])
self.head = nn.SequentialCell(head)
self.dense = nn.Dense(input_channel, num_classes, has_bias=True)
self.need_activation = True
if activation == "Sigmoid":
self.activation = P.Sigmoid()
@ -259,6 +260,7 @@ class MobileNetV2Head(nn.Cell):
def construct(self, x):
x = self.head(x)
x = self.dense(x)
if self.need_activation:
x = self.activation(x)
return x
@ -283,41 +285,6 @@ class MobileNetV2Head(nn.Cell):
if m.bias is not None:
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
@property
def get_head(self):
return self.head
class MobileNetV2(nn.Cell):
"""
MobileNetV2 architecture.
Args:
class_num (int): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns:
Tensor, output tensor.
Examples:
>>> MobileNetV2(backbone, head)
"""
def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \
round_nearest=8, input_channel=32, last_channel=1280):
super(MobileNetV2, self).__init__()
self.backbone = MobileNetV2Backbone(width_mult=width_mult, \
inverted_residual_setting=inverted_residual_setting, \
round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_features
self.head = MobileNetV2Head(input_channel=self.backbone.out_channels, num_classes=num_classes, \
has_dropout=has_dropout).get_head
def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x
class MobileNetV2Combine(nn.Cell):

View File

@ -70,10 +70,13 @@ class Monitor(Callback):
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
def __init__(self, lr_init=None, model=None, eval_dataset=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
self.model = model
self.eval_dataset = eval_dataset
self.best_acc = 0.
def epoch_begin(self, run_context):
self.losses = []
@ -84,9 +87,18 @@ class Monitor(Callback):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
eval_acc = None
if self.model is not None and self.eval_dataset is not None:
eval_acc = self.model.eval(self.eval_dataset)["acc"]
log = "epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))
np.mean(self.losses))
if eval_acc is not None:
if eval_acc > self.best_acc:
self.best_acc = eval_acc
log += ", eval_acc: {:.6f}, best_acc: {:.6f}".format(eval_acc, self.best_acc)
print(log, flush=True)
def step_begin(self, run_context):
self.step_time = time.time()

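With the model and an evaluation dataset supplied, Monitor now runs model.eval() at the end of every epoch and tracks the best accuracy; config_ckpoint in src/utils.py (next file) builds the callback exactly this way. A small sketch follows; the import path of Monitor is an assumption:

```python
from src.models import Monitor   # assumed location of the Monitor callback shown above

def train_with_epoch_eval(model, lr, train_dataset, eval_dataset, epoch_size):
    """Train while printing eval_acc / best_acc at the end of every epoch."""
    # Passing model and eval_dataset enables the per-epoch evaluation added above.
    cb = [Monitor(lr_init=lr.asnumpy(), model=model, eval_dataset=eval_dataset)]
    model.train(epoch_size, train_dataset, callbacks=cb)
```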
View File

@ -67,10 +67,9 @@ def set_context(config):
device_target=config.platform, save_graphs=False)
def config_ckpoint(config, lr, step_size):
cb = None
def config_ckpoint(config, lr, step_size, model=None, eval_dataset=None):
cb = [Monitor(lr_init=lr.asnumpy(), model=model, eval_dataset=eval_dataset)]
if config.platform in ("CPU", "GPU") or config.rank_id == 0:
cb = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,

View File

@ -24,7 +24,6 @@ from mindspore import Tensor
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_rank
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
@ -33,8 +32,9 @@ from mindspore.common import set_seed
from src.dataset import create_dataset, extract_features
from src.lr_generator import get_lr
from src.utils import context_device_init, switch_precision, config_ckpoint
from src.utils import context_device_init, config_ckpoint
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt
from src.metric import DistAccuracy, ClassifyCorrectCell
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
@ -97,10 +97,25 @@ def modelarts_pre_process():
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.pretrain_ckpt = os.path.join(config.output_path, config.pretrain_ckpt)
def build_params_groups(net):
decayed_params = []
no_decayed_params = []
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
return group_params
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_mobilenetv2():
config.dataset_path = os.path.join(config.dataset_path, 'train')
config.train_dataset_path = os.path.join(config.dataset_path, 'train')
config.eval_dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
config.device_id = get_device_id()
config.rank_id = get_rank_id()
@ -108,22 +123,23 @@ def train_mobilenetv2():
if config.platform == 'Ascend':
config.run_distribute = config.rank_size > 1.
print('\nconfig: \n', config)
print('\nconfig: {} \n'.format(config))
start = time.time()
# set context and device init
context_device_init(config)
# define network
backbone_net, head_net, net = define_net(config, config.is_training)
dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, config=config,
dataset = create_dataset(dataset_path=config.train_dataset_path, do_train=True, config=config,
enable_cache=config.enable_cache, cache_session_id=config.cache_session_id)
eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False, config=config)
step_size = dataset.get_dataset_size()
if config.platform == "GPU":
context.set_context(enable_graph_kernel=True)
if config.pretrain_ckpt:
if config.freeze_layer == "backbone":
load_ckpt(backbone_net, config.pretrain_ckpt, trainable=False)
step_size = extract_features(backbone_net, config.dataset_path, config)
step_size = extract_features(backbone_net, config.train_dataset_path, config)
elif config.filter_head:
load_ckpt(backbone_net, config.pretrain_ckpt)
else:
@ -132,9 +148,6 @@ def train_mobilenetv2():
raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
than batch_size in config.py")
# Currently, only Ascend support switch precision.
switch_precision(net, mstype.float16, config)
# define loss
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(
@ -155,11 +168,21 @@ def train_mobilenetv2():
if config.pretrain_ckpt == "" or config.freeze_layer != "backbone":
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
group_params = build_params_groups(net)
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
cb = config_ckpoint(config, lr, step_size)
metrics = {"acc"}
dist_eval_network = None
if config.run_distribute:
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)}
dist_eval_network = ClassifyCorrectCell(net)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
metrics=metrics, eval_network=dist_eval_network,
amp_level="O2", keep_batchnorm_fp32=False,
acc_level=config.acc_mode)
cb = config_ckpoint(config, lr, step_size, model, eval_dataset)
print("============== Starting Training ==============")
model.train(epoch_size, dataset, callbacks=cb)
print("============== End Training ==============")
@ -172,7 +195,7 @@ def train_mobilenetv2():
network = TrainOneStepCell(network, opt)
network.set_train()
features_path = config.dataset_path + '_features'
features_path = config.train_dataset_path + '_features'
idx_list = list(range(step_size))
rank = 0
if config.run_distribute: