add gpu resnext50

parent ca6da6751f
commit 20e5f7196e

@@ -90,10 +90,15 @@ sh run_standalone_train.sh DEVICE_ID DATA_PATH
 #### Launch

 ```bash
-# distributed training example(8p)
+# distributed training example(8p) for Ascend
 sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /dataset/train
-# standalone training example
+# standalone training example for Ascend
 sh scripts/run_standalone_train.sh 0 /dataset/train

+# distributed training example(8p) for GPU
+sh scripts/run_distribute_train_for_gpu.sh /dataset/train
+
+# standalone training example for GPU
+sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
 ```

 #### Result
@@ -106,14 +111,15 @@ You can find checkpoint file together with result in log.

 ```
 # Evaluation
-sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH
+sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
 ```
+PLATFORM is Ascend or GPU, default is Ascend.

 #### Launch

 ```bash
 # Evaluation with checkpoint
-sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt
+sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt Ascend
 ```

 > checkpoint can be produced in training process.
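With the new PLATFORM argument, a GPU evaluation would be launched as `sh scripts/run_eval.sh 0 /dataset/val /resnext50_100.ckpt GPU` (paths illustrative); omitting the fourth argument falls back to the Ascend default.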
@@ -29,15 +29,11 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype

 from src.utils.logging import get_logger
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.image_classification import get_network
 from src.dataset import classification_dataset
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)


 class ParameterReduce(nn.Cell):
     """ParameterReduce"""
@@ -56,6 +52,7 @@ class ParameterReduce(nn.Cell):
 def parse_args(cloud_args=None):
     """parse_args"""
     parser = argparse.ArgumentParser('mindspore classification test')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')
@@ -108,12 +105,25 @@ def merge_args(args, cloud_args):
 def test(cloud_args=None):
     """test"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        elif args.platform == "GPU":
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
+    else:
+        args.rank = 0
+        args.group_size = 1

     args.outputs_dir = os.path.join(args.log_path,
                                     datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
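The isdigit() guard is doing real work here: the Ascend launch scripts export DEVICE_ID, while GPU runs under mpirun typically leave it unset, so the device is only pinned when the variable is present and numeric. A standalone sketch of the check:

```python
import os

# No DEVICE_ID in the environment (typical GPU/mpirun launch):
os.environ.pop('DEVICE_ID', None)
print(os.getenv('DEVICE_ID', "not_set").isdigit())  # False -> device_id left unset

# DEVICE_ID exported by a launch script (Ascend case):
os.environ['DEVICE_ID'] = '3'
print(os.getenv('DEVICE_ID', "not_set").isdigit())  # True -> device_id=3 is applied
```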
@@ -140,7 +150,7 @@ def test(cloud_args=None):
                                         max_epoch=1, rank=args.rank, group_size=args.group_size,
                                         mode='eval')
     eval_dataloader = de_dataset.create_tuple_iterator()
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))
@@ -157,12 +167,13 @@ def test(cloud_args=None):
     load_param_into_net(network, param_dict_new)
     args.logger.info('load model {} success'.format(model))

-    # must add
-    network.add_flags_recursive(fp16=True)
-
     img_tot = 0
     top1_correct = 0
     top5_correct = 0
+    if args.platform == "Ascend":
+        network.to_float(mstype.float16)
+    else:
+        auto_mixed_precision(network)
     network.set_train(False)
     t_end = time.time()
     it = 0
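The unconditional fp16 cast is gone: Ascend still evaluates fully in float16 via to_float, while the GPU path calls auto_mixed_precision (the new module added below), which keeps the BatchNorm cells in float32.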
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$1
+export RANK_SIZE=8
+PATH_CHECKPOINT=""
+if [ $# == 2 ]
+then
+    PATH_CHECKPOINT=$2
+fi
+
+mpirun --allow-run-as-root -n $RANK_SIZE \
+    python train.py \
+    --is_distribute=1 \
+    --platform="GPU" \
+    --pretrained=$PATH_CHECKPOINT \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
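Note the contrast with the Ascend launcher: there is no MINDSPORE_HCCL_CONFIG_PATH argument because mpirun spawns the $RANK_SIZE worker processes and each one discovers its rank through init("nccl")/get_rank() rather than an HCCL rank-table file.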
@@ -14,11 +14,16 @@
 # limitations under the License.
 # ============================================================================

-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=$3
+PLATFORM=Ascend
+if [ $# == 4 ]
+then
+    PLATFORM=$4
+fi

 python eval.py \
-    --device_id=$DEVICE_ID \
     --pretrained=$PATH_CHECKPOINT \
+    --platform=$PLATFORM \
     --data_dir=$DATA_DIR > log.txt 2>&1 &
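DEVICE_ID is now exported into the environment instead of being passed as a --device_id flag; eval.py reads it back through os.getenv('DEVICE_ID'), which is exactly what the isdigit() guard above consumes.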
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ============================================================================

-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=""
 if [ $# == 3 ]
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=$1
+DATA_DIR=$2
+PATH_CHECKPOINT=""
+if [ $# == 3 ]
+then
+    PATH_CHECKPOINT=$3
+fi
+
+python train.py \
+    --is_distribute=0 \
+    --pretrained=$PATH_CHECKPOINT \
+    --platform="GPU" \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
@@ -87,7 +87,8 @@ class BasicBlock(nn.Cell):
     """
     expansion = 1

-    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
+    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False,
+                 platform="Ascend", **kwargs):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
         self.bn1 = nn.BatchNorm2d(out_channels)
@@ -142,7 +143,7 @@ class Bottleneck(nn.Cell):
     expansion = 4

     def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
-                 base_width=64, groups=1, use_se=False, **kwargs):
+                 base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs):
         super(Bottleneck, self).__init__()

         width = int(out_channels * (base_width / 64.0)) * groups
@@ -153,7 +154,11 @@ class Bottleneck(nn.Cell):

         self.conv3x3s = nn.CellList()

-        self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
+        if platform == "GPU":
+            self.conv2 = nn.Conv2d(width, width, 3, stride, pad_mode='pad', padding=1, group=groups)
+        else:
+            self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)

         self.op_split = Split(axis=1, output_num=self.groups)
         self.op_concat = Concat(axis=1)
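The branch exists because the GPU backend expresses grouped convolution natively through nn.Conv2d's group argument, while the Ascend path keeps the repo's GroupConv cell. GroupConv itself is outside this diff; a sketch of what such a cell typically looks like (illustrative, not the repo's exact code):

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class GroupConvSketch(nn.Cell):
    """Grouped conv emulated as channel split -> per-group conv -> concat."""
    def __init__(self, in_channels, out_channels, kernel_size, stride, pad, groups):
        super(GroupConvSketch, self).__init__()
        self.groups = groups
        self.split = P.Split(axis=1, output_num=groups)
        self.concat = P.Concat(axis=1)
        self.convs = nn.CellList([nn.Conv2d(in_channels // groups, out_channels // groups,
                                            kernel_size, stride, pad_mode='pad', padding=pad)
                                  for _ in range(groups)])

    def construct(self, x):
        pieces = self.split(x)  # tuple of (N, C/groups, H, W) tensors
        outs = ()
        for i in range(self.groups):
            outs = outs + (self.convs[i](pieces[i]),)
        return self.concat(outs)
```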
@@ -211,7 +216,7 @@ class ResNet(nn.Cell):
     Examples:
         >>>ResNet()
     """
-    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False):
+    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False, platform="Ascend"):
         super(ResNet, self).__init__()
         self.in_channels = 64
         self.groups = groups
@@ -222,10 +227,10 @@ class ResNet(nn.Cell):
         self.relu = P.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

-        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se, platform=platform)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se, platform=platform)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se, platform=platform)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se, platform=platform)

         self.out_channels = 512 * block.expansion
         self.cast = P.Cast()
@@ -242,7 +247,7 @@ class ResNet(nn.Cell):

         return x

-    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
+    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"):
         """_make_layer"""
         down_sample = None
         if stride != 1 or self.in_channels != out_channels * block.expansion:
@@ -257,11 +262,12 @@ class ResNet(nn.Cell):
                             down_sample=down_sample,
                             base_width=self.base_width,
                             groups=self.groups,
-                            use_se=use_se))
+                            use_se=use_se,
+                            platform=platform))
         self.in_channels = out_channels * block.expansion
         for _ in range(1, blocks_num):
-            layers.append(block(self.in_channels, out_channels,
-                                base_width=self.base_width, groups=self.groups, use_se=use_se))
+            layers.append(block(self.in_channels, out_channels, base_width=self.base_width,
+                                groups=self.groups, use_se=use_se, platform=platform))

         return nn.SequentialCell(layers)
@@ -269,5 +275,5 @@ class ResNet(nn.Cell):
         return self.out_channels


-def resnext50():
-    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32)
+def resnext50(platform="Ascend"):
+    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform)
@@ -36,7 +36,8 @@ config = ed({
     "label_smooth": 1,
     "label_smooth_factor": 0.1,

-    "ckpt_interval": 1250,
+    "ckpt_interval": 5,
+    "ckpt_save_max": 5,
     "ckpt_path": 'outputs/',
     "is_save_on_master": 1,
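This only reads correctly together with the CheckpointConfig change at the bottom of the diff: save_checkpoint_steps used to be args.ckpt_interval itself (hence the step-sized 1250), and becomes args.ckpt_interval * args.steps_per_epoch, so the config value now counts epochs, with ckpt_save_max capping the number of files kept. Illustrative numbers:

```python
# Hypothetical run: 1250 optimizer steps per epoch with the new config values.
steps_per_epoch = 1250
ckpt_interval = 5                        # now interpreted as epochs
print(ckpt_interval * steps_per_epoch)   # 6250 -> steps between saved checkpoints
```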
@@ -143,8 +143,10 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
         de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
         de_dataset.set_dataset_size(len(sampler))

-    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
-    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
+                                operations=transform_img)
+    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
+                                operations=transform_label)

     columns_to_project = ["image", "label"]
     de_dataset = de_dataset.project(columns=columns_to_project)
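num_parallel_workers is presumably threaded through classification_dataset's signature (that part of the function is outside the hunk); train.py passes num_parallel_workers=8 at its call site below, so the worker count becomes the caller's choice instead of a hard-coded 8.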
@@ -50,9 +50,9 @@ class Resnet(ImageClassificationNetwork):
     Returns:
         Resnet.
     """
-    def __init__(self, backbone_name, num_classes):
+    def __init__(self, backbone_name, num_classes, platform="Ascend"):
         self.backbone_name = backbone_name
-        backbone = backbones.__dict__[self.backbone_name]()
+        backbone = backbones.__dict__[self.backbone_name](platform=platform)
         out_channels = backbone.get_out_channels()
         head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
         super(Resnet, self).__init__(backbone, head)
@@ -79,7 +79,7 @@ class Resnet(ImageClassificationNetwork):



-def get_network(backbone_name, num_classes):
+def get_network(backbone_name, num_classes, platform="Ascend"):
     if backbone_name in ['resnext50']:
-        return Resnet(backbone_name, num_classes)
+        return Resnet(backbone_name, num_classes, platform)
     return None
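End to end, the flag now flows get_network -> Resnet -> backbones.resnext50 -> ResNet -> every block. Building the GPU variant would look like this (num_classes value illustrative):

```python
from src.image_classification import get_network

# platform reaches each Bottleneck, selecting nn.Conv2d(group=...) over GroupConv.
network = get_network('resnext50', num_classes=1000, platform='GPU')
```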
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Auto mixed precision."""
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+from mindspore._checkparam import Validator as validator
+from mindspore.common import dtype as mstype
+
+
+class OutputTo(nn.Cell):
+    "Cast cell output back to float16 or float32"
+
+    def __init__(self, op, to_type=mstype.float16):
+        super(OutputTo, self).__init__(auto_prefix=False)
+        self._op = op
+        validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
+        self.to_type = to_type
+
+    def construct(self, x):
+        return F.cast(self._op(x), self.to_type)
+
+
+def auto_mixed_precision(network):
+    """Do keep batchnorm fp32."""
+    cells = network.name_cells()
+    change = False
+    network.to_float(mstype.float16)
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif name == 'fc':
+            network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
+            change = True
+        elif name == 'conv2':
+            subcell.to_float(mstype.float32)
+            change = True
+        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+            network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
+            change = True
+        else:
+            auto_mixed_precision(subcell)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
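A minimal sketch of applying the pass, assuming a MindSpore install of the same era (the toy network is illustrative):

```python
import mindspore.nn as nn
from src.utils.auto_mixed_precision import auto_mixed_precision

# Conv and ReLU are cast to float16; the BatchNorm is computed in float32 and its
# output cast back to float16 by the OutputTo wrapper.
net = nn.SequentialCell([nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU()])
auto_mixed_precision(net)
```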
@@ -29,14 +29,10 @@ class GlobalAvgPooling(nn.Cell):
     """
     def __init__(self):
         super(GlobalAvgPooling, self).__init__()
-        self.mean = P.ReduceMean(True)
-        self.shape = P.Shape()
-        self.reshape = P.Reshape()
+        self.mean = P.ReduceMean(False)

     def construct(self, x):
         x = self.mean(x, (2, 3))
-        b, c, _, _ = self.shape(x)
-        x = self.reshape(x, (b, c))
         return x
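The deleted lines were dead weight once keep_dims is False: reducing over axes (2, 3) without keeping them already returns an (N, C) tensor, so no Shape/Reshape is needed. A quick check (illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

x = ms.Tensor(np.ones((2, 8, 7, 7), np.float32))
y = P.ReduceMean(keep_dims=False)(x, (2, 3))
print(y.asnumpy().shape)  # (2, 8): axes 2 and 3 are dropped, no reshape required
```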
@@ -36,11 +36,9 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
 from src.utils.logging import get_logger
 from src.utils.optimizers__init__ import get_param_groups
 from src.image_classification import get_network
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)

 class BuildTrainNetwork(nn.Cell):
     """build training network"""
@@ -109,6 +107,7 @@ class ProgressMonitor(Callback):
 def parse_args(cloud_args=None):
     """parameters"""
     parser = argparse.ArgumentParser('mindspore classification training')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='', help='train data dir')
@@ -141,6 +140,7 @@ def parse_args(cloud_args=None):
     args.label_smooth = config.label_smooth
     args.label_smooth_factor = config.label_smooth_factor
     args.ckpt_interval = config.ckpt_interval
+    args.ckpt_save_max = config.ckpt_save_max
     args.ckpt_path = config.ckpt_path
     args.is_save_on_master = config.is_save_on_master
     args.rank = config.rank
@@ -166,12 +166,25 @@ def merge_args(args, cloud_args):
 def train(cloud_args=None):
     """training process"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        else:
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
+    else:
+        args.rank = 0
+        args.group_size = 1

     if args.is_dynamic_loss_scale == 1:
         args.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt
@@ -192,7 +205,7 @@ def train(cloud_args=None):
     # dataloader
     de_dataset = classification_dataset(args.data_dir, args.image_size,
                                         args.per_batch_size, 1,
-                                        args.rank, args.group_size)
+                                        args.rank, args.group_size, num_parallel_workers=8)
     de_dataset.map_model = 4  # !!!important
     args.steps_per_epoch = de_dataset.get_dataset_size()
@@ -201,15 +214,9 @@ def train(cloud_args=None):
     # network
     args.logger.important_info('start create network')
     # get network and init
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))
-    network.add_flags_recursive(fp16=True)
-    # loss
-    if not args.label_smooth:
-        args.label_smooth_factor = 0.0
-    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
-                             num_classes=args.num_classes)

     # load pretrain model
     if os.path.isfile(args.pretrained):
@@ -252,31 +259,29 @@ def train(cloud_args=None):
                    loss_scale=args.loss_scale)


-    criterion.add_flags_recursive(fp32=True)
+    # loss
+    if not args.label_smooth:
+        args.label_smooth_factor = 0.0
+    loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

-    # package training process, adjust lr + forward + backward + optimizer
-    train_net = BuildTrainNetwork(network, criterion)
-    if args.is_distributed:
-        parallel_mode = ParallelMode.DATA_PARALLEL
-    else:
-        parallel_mode = ParallelMode.STAND_ALONE
     if args.is_dynamic_loss_scale == 1:
         loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
     else:
         loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

-    # Model api changed since TR5_branch 2020/03/09
-    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                      parameter_broadcast=True, mirror_mean=True)
-    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)
+    if args.platform == "Ascend":
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
+                      metrics={'acc'}, amp_level="O3")
+    else:
+        auto_mixed_precision(network)
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})

     # checkpoint save
     progress_cb = ProgressMonitor(args)
     callbacks = [progress_cb,]
     if args.rank_save_ckpt_flag:
-        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
-                                       keep_checkpoint_max=ckpt_max_num)
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
+                                       keep_checkpoint_max=args.ckpt_save_max)
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
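The two Model constructions encode the two mixed-precision strategies: on Ascend, amp_level="O3" lets Model cast the network to float16 itself, replacing the old add_flags_recursive(fp16=True) plus hand-built BuildTrainNetwork; on GPU, the custom auto_mixed_precision pass runs first and Model is built without an amp level, keeping BatchNorm in float32.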