forked from mindspore-Ecosystem/mindspore
add pre-training scripts
This commit is contained in:
parent
11cf74e6e8
commit
aad655b778
|
@ -23,7 +23,7 @@
|
|||
|
||||
ICNet (Image Cascade Network) proposes a fully convolutional network that incorporates multi-resolution branches under proper label guidance to address the challenge of real-time semantic segmentation.
|
||||
|
||||
[paper](https://arxiv.org/abs/1704.08545)ECCV2018
|
||||
[paper](https://arxiv.org/abs/1704.08545) from ECCV2018
|
||||
|
||||
# [Model Architecture](#Contents)
|
||||
|
||||
|
@ -31,7 +31,7 @@ ICNet takes cascade image inputs (i.e., low-, medium- and high resolution images
|
|||
|
||||
# [Dataset](#Content)
|
||||
|
||||
used Dataset :[Cityscape Dataset Website](https://www.cityscapes-dataset.com/)
|
||||
Dataset used: [Cityscapes Dataset Website](https://www.cityscapes-dataset.com/) (please download the 1st and 3rd zip files)
|
||||
|
||||
It contains 5,000 finely annotated images split into training, validation and testing sets with 2,975, 500, and 1,525 images respectively.
|
||||
|
||||
|
@ -64,6 +64,16 @@ It contains 5,000 finely annotated images split into training, validation and te
|
|||
├── export.py # export mindir
|
||||
├── postprocess.py # 310 infer calculate accuracy
|
||||
├── README.md # descriptions about ICNet
|
||||
├── Res50V1_PRE # scripts for pretrain
|
||||
│ ├── scripts
|
||||
│ │ └── run_distribute_train.sh
|
||||
│ ├── src
|
||||
│ │ ├── config.py
|
||||
│ │ ├── CrossEntropySmooth.py
|
||||
│ │ ├── dataset.py
|
||||
│ │ ├── lr_generator.py
|
||||
│ │ └── resnet50_v1.py
|
||||
│ └── train.py
|
||||
├── scripts
|
||||
│ ├── run_distribute_train8p.sh # multi cards distributed training in ascend
|
||||
│ ├── run_eval.sh # validation script
|
||||
|
@ -95,7 +105,7 @@ Set script parameters in src/model_utils/icnet.yaml .
|
|||
|
||||
```bash
|
||||
name: "icnet"
|
||||
backbone: "resnet50"
|
||||
backbone: "resnet50v1"
|
||||
base_size: 1024 # during augmentation, shorter size will be resized between [base_size*0.5, base_size*2.0]
|
||||
crop_size: 960 # end of augmentation, crop to training
|
||||
```
|
||||
|
@ -116,9 +126,8 @@ valid_batch_size: 1
|
|||
cityscapes_root: "/data/cityscapes/" # set dataset path
|
||||
epochs: 160
|
||||
val_epoch: 1
|
||||
ckpt_dir: "./ckpt/" # ckpt and training log will be saved here
|
||||
mindrecord_dir: '' # set mindrecord path
|
||||
pretrained_model_path: '/root/ResNet50V1B-150_625.ckpt' # set the pretrained model path correctly
|
||||
pretrained_model_path: '/root/ResNet50V1B-150_625.ckpt' # use the latest checkpoint file after pre-training
|
||||
save_checkpoint_epochs: 5
|
||||
keep_checkpoint_max: 10
|
||||
```
|
||||
|
@ -137,18 +146,28 @@ keep_checkpoint_max: 10
|
|||
|
||||
[MINDRCORD_PATH] in script should be consistent with 'mindrecord_dir' in config file.
|
||||
|
||||
### Distributed Training
|
||||
### Pre-training
|
||||
|
||||
- Run distributed train in ascend processor environment
|
||||
The folder Res50V1_PRE contains the pre-training scripts; the pre-training dataset is [ImageNet](https://image-net.org/). For more details, see [GENet_Res50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/GENet_Res50).
|
||||
|
||||
- Usage:
|
||||
|
||||
```shell
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PROJECT_PATH]
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
```
|
||||
|
||||
- Notes:
|
||||
|
||||
The hccl.json file specified by [RANK_TABLE_FILE] is used when running distributed tasks. You can use [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate this file.
|
||||
|
||||
### Distributed Training
|
||||
|
||||
- Run distributed train in ascend processor environment
|
||||
|
||||
```shell
|
||||
bash scripts/run_distribute_train8p.sh [RANK_TABLE_FILE] [PROJECT_PATH]
|
||||
```
|
||||
|
||||
### Training Result
|
||||
|
||||
The training results will be saved in the example path, in a folder whose name starts with "ICNet-". You can find the checkpoint file and results similar to those below in LOG(0-7)/log.txt.
|
||||
|
@ -174,7 +193,7 @@ epoch time: 97117.785 ms, per step time: 1044.277 ms
|
|||
Check the checkpoint path used for evaluation before running the following command.
|
||||
|
||||
```shell
|
||||
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]
|
||||
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
### Evaluation Result
|
||||
|
@ -196,7 +215,7 @@ avgtime 0.19648232793807982
|
|||
bash run_infer_310.sh [The path of the MINDIR for 310 infer] [The path of the dataset for 310 infer] 0
|
||||
```
|
||||
|
||||
Note:: Before executing 310 infer, create the MINDIR/AIR model using "python export.py --ckpt-file [The path of the CKPT for exporting]".
|
||||
- Note: Before executing 310 infer, create the MINDIR/AIR model using "python export.py --ckpt-file [The path of the CKPT for exporting]".
|
||||
|
||||
# [Model Description](#Content)
|
||||
|
||||
|
@ -204,7 +223,7 @@ Note:: Before executing 310 infer, create the MINDIR/AIR model using "python exp
|
|||
|
||||
### Training Performance
|
||||
|
||||
|Parameter | MaskRCNN |
|
||||
|Parameter | ICNet |
|
||||
| ------------------- | --------------------------------------------------------- |
|
||||
|resources | Ascend 910;CPU 2.60GHz, 192core;memory:755G |
|
||||
|Upload date |2021.6.1 |
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# The launcher needs exactly two positional arguments; anything else prints
# the usage line and aborts.
if [ $# -ne 2 ]; then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
    exit 1
fi
|
||||
|
||||
# Print an absolute form of the path given as $1.
#   - A path that already starts with "/" is echoed unchanged.
#   - Any other path is resolved against $PWD with `realpath -m`, which does
#     not require the path components to exist.
# All expansions are quoted so that paths containing spaces or glob
# characters stay a single argument.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
|
||||
|
||||
# Resolve both arguments to absolute paths. Quoting keeps paths with spaces
# intact all the way through the checks and the python command line.
PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

if [ ! -f "$PATH1" ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d "$PATH2" ]
then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

# Single-server, eight-device layout; rank ids start at DEVICE_NUM * SERVER_ID.
export SERVER_ID=0
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
rank_start=$((DEVICE_NUM * SERVER_ID))
first_device=0
export RANK_TABLE_FILE="$PATH1"

# Start one background training process per device, each in its own fresh
# working directory holding a copy of the python sources and scripts.
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$((first_device+i))
    export RANK_ID=$((rank_start + i))
    rm -rf "./train_parallel$i"
    mkdir "./train_parallel$i"
    cp ../*.py "./train_parallel$i"
    cp *.sh "./train_parallel$i"
    cp -r ../src "./train_parallel$i"
    cd "./train_parallel$i" || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --data_url="$PATH2" &> log &

    cd ..
done
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network"""
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CrossEntropySmooth(_Loss):
    """Softmax cross-entropy loss with optional label smoothing.

    Args:
        sparse (bool): when True, ``label`` is one-hot encoded inside
            ``construct`` before the loss is computed. Default: True.
        reduction (str): forwarded to ``nn.SoftmaxCrossEntropyWithLogits``.
            Default: 'mean'.
        smooth_factor (float): the one-hot "on" value is
            ``1 - smooth_factor`` and the "off" value is
            ``smooth_factor / num_classes``. Default: 0.
        num_classes (int): divisor used for the "off" value. Default: 1000.
    """

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super().__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        # Values written into the one-hot encoding for the hit / miss slots.
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / num_classes, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        """Return the cross-entropy between ``logit`` and ``label``."""
        if self.sparse:
            # The one-hot depth is taken from dimension 1 of the logits.
            depth = F.shape(logit)[1]
            label = self.onehot(label, depth, self.on_value, self.off_value)
        return self.ce(logit, label)
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
# Optimizer choice for resnet50 on imagenet2012. Momentum is default, Thor is optional.
cfg = ed(dict(optimizer='Momentum'))

# Hyper-parameters read by the pre-training script.
config1 = ed(dict(
    class_num=1000,
    batch_size=256,
    loss_scale=1024,
    momentum=0.9,
    weight_decay=1e-4,
    epoch_size=150,
    pretrain_epoch_size=0,
    save_checkpoint=True,
    save_checkpoint_epochs=5,
    keep_checkpoint_max=5,
    decay_mode="linear",
    save_checkpoint_path="./checkpoints",
    hold_epochs=0,
    use_label_smooth=True,
    label_smooth_factor=0.1,
    lr_init=0.8,
    lr_end=0.0,
))
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32,
                   target="Ascend", distribute=False):
    """
    Build the image-folder pipeline used for training or evaluation.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): True selects the random-crop/flip training transforms,
            False the resize/center-crop evaluation transforms.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend
        distribute(bool): only consulted when target is not "Ascend"; when
            True the communication layer is initialised and the dataset is
            sharded by rank. Default: False

    Returns:
        dataset
    """
    # Decide how many shards exist and which one this process reads.
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif distribute:
        init()
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1

    folder_kwargs = dict(num_parallel_workers=8, shuffle=True)
    if device_num != 1:
        folder_kwargs.update(num_shards=device_num, shard_id=rank_id)
    data_set = ds.ImageFolderDataset(dataset_path, **folder_kwargs)

    image_size = 224
    # Per-channel statistics scaled to the 0-255 pixel range.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # Image transforms: a mode-specific head followed by a shared tail.
    if do_train:
        head_ops = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        head_ops = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
        ]
    trans = head_ops + [C.Normalize(mean=mean, std=std), C.HWC2CHW()]

    label_cast = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=label_cast, input_columns="label", num_parallel_workers=8)

    # Batch (dropping the incomplete last batch), then repeat.
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set.repeat(repeat_num)
|
||||
|
||||
|
||||
|
||||
def _get_rank_info():
    """
    Return ``(rank_size, rank_id)`` for the current process.

    The ``RANK_SIZE`` environment variable decides the mode: a value greater
    than 1 means distributed execution, in which case the group size and rank
    are queried from the communication layer. Otherwise ``(1, 0)`` is
    returned.

    Returns:
        tuple(int, int), number of devices and the rank of this process.
    """
    # An unset or empty RANK_SIZE is treated as a single-device run instead
    # of failing inside int().
    rank_size = int(os.getenv("RANK_SIZE") or "1")

    if rank_size > 1:
        return get_group_size(), get_rank()
    return 1, 0
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _generate_linear_lr(lr_init, lr_end, total_steps):
    """
    Build a linearly decaying learning-rate schedule.

    Step ``i`` gets ``lr_init - (lr_init - lr_end) * i / total_steps``, so the
    schedule starts at ``lr_init`` and approaches, but never reaches,
    ``lr_end``.

    Args:
       lr_init(float): init learning rate.
       lr_end(float): end learning rate
       total_steps(int): all steps in training.

    Returns:
       list of float, one learning rate per step.
    """
    return [lr_init - (lr_init - lr_end) * step / total_steps
            for step in range(total_steps)]
|
||||
|
||||
def _generate_cosine_lr(lr_init, total_steps):
    """
    Build a learning-rate schedule that multiplies a linear ramp by a cosine term.

    Args:
       lr_init(float): init learning rate.
       total_steps(int): all steps in training.

    Returns:
       list of float, one learning rate per step.
    """
    decay_steps = total_steps

    def _decay_factor(step):
        # Linear ramp from 1 towards 0, scaled by a partial cosine wave; the
        # small constant keeps the factor strictly positive.
        linear_decay = (total_steps - step) / decay_steps
        cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * step / decay_steps))
        return linear_decay * cosine_decay + 0.00001

    return [lr_init * _decay_factor(step) for step in range(total_steps)]
|
||||
|
||||
def get_lr(lr_init, lr_end, total_epochs, steps_per_epoch, decay_mode):
    """
    Generate the per-step learning-rate array for a whole training run.

    Args:
       lr_init(float): init learning rate
       lr_end(float): end learning rate (only used by the linear schedule)
       total_epochs(int): total epoch of training
       steps_per_epoch(int): steps of one epoch
       decay_mode(str): "cosine" selects the cosine schedule; any other value
           selects the linear schedule.

    Returns:
       np.array of float32, learning rate array
    """
    total_steps = steps_per_epoch * total_epochs

    if decay_mode == "cosine":
        schedule = _generate_cosine_lr(lr_init, total_steps)
    else:
        schedule = _generate_linear_lr(lr_init, lr_end, total_steps)

    return np.array(schedule).astype(np.float32)
|
|
@ -0,0 +1,288 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""pretrained model resnet50"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore as ms
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
    """Return the initialisation gain for the named nonlinearity.

    Args:
        nonlinearity (str): a linear/convolution name, 'sigmoid', 'tanh',
            'relu' or 'leaky_relu'.
        param: negative slope for 'leaky_relu'; None selects 0.01. Ignored
            for every other nonlinearity.

    Returns:
        1 for linear/convolution layers and sigmoid, 5/3 for tanh, sqrt(2)
        for relu and sqrt(2 / (1 + slope ** 2)) for leaky_relu.

    Raises:
        ValueError: if the nonlinearity is unknown, or if ``param`` is not an
            int/float (bool is rejected) for 'leaky_relu'.
    """
    linear_like = ('linear', 'conv1d', 'conv2d', 'conv3d',
                   'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d')

    if nonlinearity in linear_like or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
        # bool is a subclass of int, so it has to be excluded explicitly.
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
    """Compute (fan_in, fan_out) for a weight of the given shape.

    Args:
        tensor: the weight *shape* (a sequence of ints), laid out as
            ``(out_channels, in_channels, *kernel_dims)``.

    Returns:
        tuple(int, int): ``in_channels * k`` and ``out_channels * k``, where
        ``k`` is the product of all kernel dimensions (1 for a 2-D shape).

    Raises:
        ValueError: if the shape has fewer than 2 dimensions.
    """
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # Multiply every dimension after the first two, so 1-D and 3-D kernels
    # are handled as well as the usual 2-D ones.
    receptive_field_size = 1
    for extent in tensor[2:]:
        receptive_field_size *= extent

    fan_in = tensor[1] * receptive_field_size
    fan_out = tensor[0] * receptive_field_size
    return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
    """Return the fan-in or fan-out of a weight shape.

    Args:
        tensor: weight shape, passed on to ``_calculate_fan_in_and_fan_out``.
        mode (str): 'fan_in' or 'fan_out', case-insensitive.

    Raises:
        ValueError: if ``mode`` is neither of the two supported values.
    """
    normalized = mode.lower()
    supported = ['fan_in', 'fan_out']
    if normalized not in supported:
        raise ValueError("Mode {} not supported, please use one of {}".format(normalized, supported))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if normalized == 'fan_in':
        return fan_in
    return fan_out
|
||||
|
||||
|
||||
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Sample a float32 array of ``inputs_shape`` from N(0, std**2).

    ``std`` is ``gain / sqrt(fan)``, where the gain comes from
    ``calculate_gain(nonlinearity, a)`` and the fan is selected by ``mode``.
    """
    fan_value = _calculate_correct_fan(inputs_shape, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan_value)
    samples = np.random.normal(0, std, size=inputs_shape)
    return samples.astype(np.float32)
|
||||
|
||||
|
||||
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    """Sample a float32 array of ``inputs_shape`` from U(-bound, bound).

    ``bound`` is ``sqrt(3) * gain / sqrt(fan)``, where the gain comes from
    ``calculate_gain(nonlinearity, a)`` and the fan is selected by ``mode``.
    """
    fan_value = _calculate_correct_fan(inputs_shape, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan_value)
    # A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
    limit = math.sqrt(3.0) * std
    samples = np.random.uniform(-limit, limit, size=inputs_shape)
    return samples.astype(np.float32)
|
||||
|
||||
|
||||
def _kaiming_conv(in_channel, out_channel, kernel_size, stride):
    """Build a 'same'-padded square nn.Conv2d with a kaiming-normal weight.

    The weight of shape ``(out_channel, in_channel, k, k)`` is drawn by
    ``kaiming_normal`` with mode 'fan_out' and nonlinearity 'relu'.
    """
    weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv3x3(in_channel, out_channel, stride=1):
    """Return a 3x3 'same'-padded convolution with kaiming-normal weights."""
    return _kaiming_conv(in_channel, out_channel, 3, stride)


def _conv1x1(in_channel, out_channel, stride=1):
    """Return a 1x1 'same'-padded convolution with kaiming-normal weights."""
    return _kaiming_conv(in_channel, out_channel, 1, stride)


def _conv7x7(in_channel, out_channel, stride=1):
    """Return a 7x7 'same'-padded convolution with kaiming-normal weights."""
    return _kaiming_conv(in_channel, out_channel, 7, stride)
|
||||
|
||||
|
||||
def _bn(channel):
    """Return a BatchNorm2d over ``channel`` features with gamma initialised to 1."""
    bn_kwargs = dict(eps=1e-4, momentum=0.95,
                     gamma_init=1, beta_init=0,
                     moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_kwargs)
|
||||
|
||||
|
||||
def _bn_last(channel):
    """Return a BatchNorm2d over ``channel`` features with gamma initialised to 0."""
    bn_kwargs = dict(eps=1e-4, momentum=0.95,
                     gamma_init=0, beta_init=0,
                     moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_kwargs)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel):
    """Return a biased Dense layer with a kaiming-uniform weight and zero bias."""
    init_weight = kaiming_uniform((out_channel, in_channel), a=math.sqrt(5))
    return nn.Dense(in_channel, out_channel, has_bias=True,
                    weight_init=Tensor(init_weight), bias_init=0)
|
||||
|
||||
|
||||
class BottleneckV1b(nn.Cell):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channel (int): channels of the block input.
        out_channel (int): channels of the block output; the two inner
            convolutions use ``out_channel // 4`` channels.
        stride (int): stride of the 3x3 convolution and of the shortcut
            projection.
        dilation (int): dilation of the 3x3 convolution. Default: 1.
    """

    def __init__(self, in_channel, out_channel, stride, dilation=1):
        super().__init__()
        expansion = 4
        mid_channel = out_channel // expansion

        self.conv1 = nn.Conv2dBnAct(in_channel, mid_channel, kernel_size=1, stride=1,
                                    pad_mode="same", has_bn=True, activation='relu')
        self.conv2 = nn.Conv2dBnAct(mid_channel, mid_channel, kernel_size=3, stride=stride,
                                    pad_mode="same", dilation=dilation, has_bn=True,
                                    activation='relu')
        # No activation here: the ReLU is applied after the residual addition.
        self.conv3 = nn.Conv2dBnAct(mid_channel, out_channel, kernel_size=1, stride=1,
                                    pad_mode='same', has_bn=True)

        # The shortcut is projected whenever the main path changes the
        # spatial size or the channel count.
        self.down_sample = stride != 1 or in_channel != out_channel
        self.down_layer = None
        if self.down_sample:
            self.down_layer = nn.Conv2dBnAct(in_channel, out_channel, kernel_size=1,
                                             stride=stride, pad_mode='same', has_bn=True)

        self.relu = nn.ReLU()
        self.add = ms.ops.Add()

    def construct(self, x):
        """Apply the three convolutions, add the shortcut and apply ReLU."""
        shortcut = x
        residual = self.conv1(x)
        residual = self.conv2(residual)
        residual = self.conv3(residual)

        if self.down_sample:
            shortcut = self.down_layer(shortcut)

        return self.relu(self.add(residual, shortcut))
|
||||
|
||||
|
||||
class Resnet50v1b(nn.Cell):
    """ResNet-style classifier: a 7x7 stem, four residual stages and a dense head.

    Args:
        block: cell class used for every residual block; it is called with
            ``in_channel``, ``out_channel``, ``stride`` and ``dilation``.
        layer_nums (list): number of blocks in each of the four stages.
        in_channels (list): input channels of each stage.
        out_channels (list): output channels of each stage.
        strides (list): stride of the first block of each stage.
        num_classes (int): size of the final dense layer.
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super().__init__()

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # Four residual stages; the last two use dilation 2 and 4.
        self.layer1 = self._make_layer(block, layer_nums[0], in_channels[0],
                                       out_channels[0], strides[0], 1)
        self.layer2 = self._make_layer(block, layer_nums[1], in_channels[1],
                                       out_channels[1], strides[1], 1)
        self.layer3 = self._make_layer(block, layer_nums[2], in_channels[2],
                                       out_channels[2], strides[2], 2)
        self.layer4 = self._make_layer(block, layer_nums[3], in_channels[3],
                                       out_channels[3], strides[3], 4)

        # Head: mean over axes (2, 3), flatten, dense classifier.
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, dilation):
        """Stack ``layer_num`` blocks; only the first changes channels and stride."""
        first = block(in_channel=in_channel,
                      out_channel=out_channel,
                      stride=stride,
                      dilation=dilation)
        rest = [block(in_channel=out_channel,
                      out_channel=out_channel,
                      stride=1,
                      dilation=dilation)
                for _ in range(1, layer_num)]
        return nn.SequentialCell([first] + rest)

    def construct(self, x):
        """Run the stem, the four stages and the classification head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        feat = self.maxpool(x)

        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)

        logits = self.mean(feat, (2, 3))
        logits = self.flatten(logits)
        logits = self.end_point(logits)

        return logits
|
||||
|
||||
|
||||
def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True):
    """
    Build the ResNet50-v1b network and optionally load a checkpoint into it.

    Args:
        class_num (int): Class number.
        ckpt_root (str): checkpoint file passed to ``load_checkpoint``; only
            read when ``pretrained`` is true.
        pretrained (bool): whether to load ``ckpt_root`` into the network.

    Returns:
        Cell, a ``Resnet50v1b`` instance built from ``BottleneckV1b`` blocks.

    Examples:
        >>> net = get_resnet50v1b(1001)
    """
    model = Resnet50v1b(BottleneckV1b,
                        [3, 4, 6, 3],
                        [64, 256, 512, 1024],
                        [256, 512, 1024, 2048],
                        [1, 2, 2, 2],
                        class_num)

    if not pretrained:
        return model

    load_param_into_net(model, load_checkpoint(ckpt_root))
    print("pretrained....")
    return model
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train GENet."""
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
from src.resnet50_v1 import get_resnet50v1b as net
|
||||
from src.lr_generator import get_lr
|
||||
from src.dataset import create_dataset
|
||||
from src.config import config1 as config
|
||||
|
||||
# Command-line interface of the pre-training script.
parser = argparse.ArgumentParser(description='Image classification')

parser.add_argument(
    '--data_url',
    type=str,
    default=None,
    help='Dataset path',
)
parser.add_argument(
    '--train_url',
    type=str,
    default=None,
    help='Train output path',
)
parser.add_argument(
    '--device_target',
    type=str,
    default='Ascend',
    choices=("Ascend", "GPU", "CPU"),
    help="Device target, support Ascend, GPU and CPU.",
)
parser.add_argument(
    '--pre_trained',
    type=str,
    default=None,
    help='Pretrained checkpoint path',
)
parser.add_argument(
    '--is_modelarts',
    type=str,
    default="False",
    help='is train on modelarts',
)
args_opt = parser.parse_args()

# moxing is imported only when the flag is exactly the string "True".
if args_opt.is_modelarts == "True":
    import moxing as mox

# Fixed seed for MindSpore's random number generation.
set_seed(1)
|
||||
|
||||
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """Delete, in place, every entry whose key contains a filtered name.

    Args:
        origin_dict: mapping of parameter names to values; mutated in place.
        param_filter: iterable of substrings; a key is removed as soon as any
            one of them occurs in it.

    Returns:
        None. Each removed key is reported on stdout.
    """
    # Select first, delete afterwards, so the mapping is never resized while
    # it is being iterated.
    doomed = [key for key in origin_dict
              if any(name in key for name in param_filter)]
    for key in doomed:
        print("Delete parameter from checkpoint: ", key)
        del origin_dict[key]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: build the dataset, network, optimizer and loss from
    # `config` / `args_opt`, train on Ascend, and (on ModelArts) stage the
    # dataset in and the checkpoints out through moxing.

    # Device topology is read from the launcher's environment. Default to a
    # single-device layout when a variable is unset instead of failing with
    # int(None).
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    rank_id = int(os.getenv('RANK_ID', '0'))

    ckpt_save_dir = config.save_checkpoint_path
    local_train_data_url = args_opt.data_url

    if args_opt.is_modelarts == "True":
        import subprocess  # only needed to unpack the dataset archive

        local_summary_dir = "/cache/summary"
        local_train_url = "/cache/ckpt"
        local_zipfolder_url = "/cache/tarzip"
        filename = "imagenet_original.tar.gz"

        # Checkpoints go to local cache first and are copied out after training.
        ckpt_save_dir = local_train_url
        mox.file.make_dirs(local_train_url)
        mox.file.make_dirs(local_summary_dir)

        # transfer dataset: every device copies and unpacks the archive into
        # its own per-device directory
        local_data_url = os.path.join("/cache/data", str(device_id))
        mox.file.make_dirs(local_data_url)
        local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
        obs_zip_path = os.path.join(args_opt.data_url, filename)
        mox.file.copy(obs_zip_path, local_zip_path)
        # Argument list (no shell) and check=True: a failed extraction stops
        # the run here rather than surfacing later as a missing dataset.
        subprocess.run(["tar", "-xvf", local_zip_path, "-C", local_data_url], check=True)
        local_train_data_url = os.path.join(local_data_url, "imagenet_original", "train")

    # Only Ascend is supported, although the CLI also accepts GPU and CPU.
    target = args_opt.device_target
    if target != 'Ascend':
        raise ValueError("Unsupported device target.")

    run_distribute = device_num > 1

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)

    if run_distribute:
        context.set_context(device_id=device_id,
                            enable_auto_mixed_precision=True)
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
        init()

    # create dataset
    dataset = create_dataset(dataset_path=local_train_data_url, do_train=True, repeat_num=1,
                             batch_size=config.batch_size, target=target, distribute=run_distribute)
    step_size = dataset.get_dataset_size()

    # define net
    # Bound to a new name so the imported factory `net` is not overwritten.
    network = net(class_num=config.class_num, pretrained=False)

    # init weight
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(network, param_dict)
    else:
        # No checkpoint given: HeUniform for conv weights, TruncatedNormal for
        # dense weights.
        for _, cell in network.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))

    lr = Tensor(get_lr(config.lr_init, config.lr_end, config.epoch_size, step_size, config.decay_mode))

    # define opt
    # Weight decay is applied only to parameters whose name contains none of
    # 'beta', 'gamma' or 'bias'.
    decayed_params = []
    no_decayed_params = []
    for param in network.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)

    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': network.trainable_params()}]

    opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)

    # define loss, model
    # `target` is guaranteed to be 'Ascend' by the check above, so there is no
    # alternative branch here.
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0

    loss = CrossEntropySmooth(sparse=True, reduction="mean",
                              smooth_factor=config.label_smooth_factor,
                              num_classes=config.class_num)

    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
                  metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]

    # Only rank 0 writes checkpoints.
    if rank_id == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="ResNet50V1B", directory=ckpt_save_dir, config=config_ck)
        cb += [ckpt_cb]

    # The original computed `target != "CPU"`, which is always True here
    # because only Ascend passes the target check.
    model.train(config.epoch_size, dataset, callbacks=cb,
                sink_size=step_size, dataset_sink_mode=True)

    # NOTE(review): checkpoints are saved by the process with rank_id == 0 but
    # uploaded by the one with device_id == 0 — confirm these are always the
    # same process (they may differ on multi-node jobs).
    if device_id == 0 and args_opt.is_modelarts == "True":
        mox.file.copy_parallel(ckpt_save_dir, args_opt.train_url)
|
|
@ -74,7 +74,6 @@ class Evaluator:
|
|||
mask = self._mask_transform(mask) # mask shape: (H,w)
|
||||
|
||||
image = Tensor(image)
|
||||
print(image)
|
||||
|
||||
expand_dims = ops.ExpandDims()
|
||||
image = expand_dims(image, 0)
|
||||
|
@ -84,8 +83,8 @@ class Evaluator:
|
|||
end_time = time.time()
|
||||
step_time = end_time - start_time
|
||||
|
||||
expand_dims = ops.ExpandDims()
|
||||
mask = expand_dims(mask, 0)
|
||||
output = np.array(output)
|
||||
mask = np.expand_dims(mask, axis=0)
|
||||
self.metric.update(output, mask)
|
||||
list_time.append(step_time)
|
||||
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]"
|
||||
echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -53,7 +53,7 @@ fi
|
|||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_ID=$4
|
||||
export RANK_SIZE=1
|
||||
export RANK_ID=0
|
||||
|
||||
|
@ -68,6 +68,6 @@ cp -r ../src ./eval
|
|||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --project_path=$PATH3 &> log &
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --project_path=$PATH3 --device=$4 &> log &
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
### 1.Model
|
||||
model:
|
||||
name: "icnet"
|
||||
backbone: "resnet50"
|
||||
backbone: "resnet50v1"
|
||||
base_size: 1024 # during augmentation, shorter size will be resized between [base_size*0.5, base_size*2.0]
|
||||
crop_size: 960 # end of augmentation, crop to training
|
||||
|
||||
|
|
Loading…
Reference in New Issue