From aad655b77870b030868ed1430f3f2e74a1f048ff Mon Sep 17 00:00:00 2001
From: xinping li <2284498467@qq.com>
Date: Tue, 3 Aug 2021 18:23:41 +0800
Subject: [PATCH] add pre-training scripts

---
 model_zoo/research/cv/ICNet/README.md         |  41 ++-
 .../scripts/run_distribute_train.sh           |  71 +++++
 .../Res50V1_PRE/src/CrossEntropySmooth.py     |  38 +++
 .../cv/ICNet/Res50V1_PRE/src/config.py        |  42 +++
 .../cv/ICNet/Res50V1_PRE/src/dataset.py       | 108 +++++++
 .../cv/ICNet/Res50V1_PRE/src/lr_generator.py  |  84 +++++
 .../cv/ICNet/Res50V1_PRE/src/resnet50_v1.py   | 288 ++++++++++++++++++
 .../research/cv/ICNet/Res50V1_PRE/train.py    | 189 ++++++++++++
 model_zoo/research/cv/ICNet/eval.py           |   5 +-
 .../research/cv/ICNet/scripts/run_eval.sh     |   8 +-
 .../cv/ICNet/src/model_utils/icnet.yaml       |   2 +-
 11 files changed, 857 insertions(+), 19 deletions(-)
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/scripts/run_distribute_train.sh
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/src/CrossEntropySmooth.py
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/src/config.py
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/src/dataset.py
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/src/lr_generator.py
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/src/resnet50_v1.py
 create mode 100644 model_zoo/research/cv/ICNet/Res50V1_PRE/train.py

diff --git a/model_zoo/research/cv/ICNet/README.md b/model_zoo/research/cv/ICNet/README.md
index c2496b09bd7..8b330c1d874 100644
--- a/model_zoo/research/cv/ICNet/README.md
+++ b/model_zoo/research/cv/ICNet/README.md
@@ -23,7 +23,7 @@
 
 ICNet (Image Cascade Network) proposes a fully convolutional network which incorporates multi-resolution branches under proper label guidance to address the challenge of real-time semantic segmentation.
 
-[paper](https://arxiv.org/abs/1704.08545)ECCV2018
+[paper](https://arxiv.org/abs/1704.08545) (ECCV 2018)
 
 # [Model Architecture](#Contents)
 
@@ -31,7 +31,7 @@ ICNet takes cascade image inputs (i.e., low-, medium- and high resolution images
 
 # [Dataset](#Content)
 
-used Dataset :[Cityscape Dataset Website](https://www.cityscapes-dataset.com/)
+Dataset used: [Cityscapes Dataset Website](https://www.cityscapes-dataset.com/) (please download the 1st and 3rd zip files)
 
 It contains 5,000 finely annotated images split into training, validation and testing sets with 2,975, 500, and 1,525 images respectively.
@@ -64,6 +64,16 @@ It contains 5,000 finely annotated images split into training, validation and te
 ├── export.py                # export mindir
 ├── postprocess.py           # 310 infer calculate accuracy
 ├── README.md                # descriptions about ICNet
+├── Res50V1_PRE              # scripts for pre-training
+│   ├── scripts
+│   │   └── run_distribute_train.sh
+│   ├── src
+│   │   ├── config.py
+│   │   ├── CrossEntropySmooth.py
+│   │   ├── dataset.py
+│   │   ├── lr_generator.py
+│   │   └── resnet50_v1.py
+│   └── train.py
 ├── scripts
 │   ├── run_distribute_train8p.sh   # multi cards distributed training in ascend
 │   ├── run_eval.sh                 # validation script
@@ -95,7 +105,7 @@ Set script parameters in src/model_utils/icnet.yaml .
 ```bash
 name: "icnet"
-backbone: "resnet50"
+backbone: "resnet50v1"
 base_size: 1024    # during augmentation, shorter size will be resized between [base_size*0.5, base_size*2.0]
 crop_size: 960     # end of augmentation, crop to training
 ```
@@ -116,9 +126,8 @@ valid_batch_size: 1
 cityscapes_root: "/data/cityscapes/"  # set dataset path
 epochs: 160
 val_epoch: 1
-ckpt_dir: "./ckpt/" # ckpt and training log will be saved here
 mindrecord_dir: ''  # set mindrecord path
-pretrained_model_path: '/root/ResNet50V1B-150_625.ckpt' # set the pretrained model path correctly
+pretrained_model_path: '/root/ResNet50V1B-150_625.ckpt'  # use the latest checkpoint file after pre-training
 save_checkpoint_epochs: 5
 keep_checkpoint_max: 10
 ```
@@ -137,18 +146,28 @@ keep_checkpoint_max: 10
 
 [MINDRECORD_PATH] in the script should be consistent with 'mindrecord_dir' in the config file.
 
-### Distributed Training
+### Pre-training
 
-- Run distributed train in ascend processor environment
+The folder Res50V1_PRE contains the pre-training scripts; the dataset is [ImageNet](https://image-net.org/). More details are available in [GENet_Res50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/GENet_Res50).
+
+- Usage:
 
 ```shell
-  bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PROJECT_PATH]
+  bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
 ```
 
 - Notes: The hccl.json file specified by [RANK_TABLE_FILE] is used when running distributed tasks. You can use [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate this file.
 
+### Distributed Training
+
+- Run distributed train in ascend processor environment
+
+```shell
+  bash scripts/run_distribute_train8p.sh [RANK_TABLE_FILE] [PROJECT_PATH]
+```
+
 ### Training Result
 
 The training results will be saved in the example path. The folder name starts with "ICNet-". You can find the checkpoint file and similar results in LOG(0-7)/log.txt.
@@ -174,7 +193,7 @@ epoch time: 97117.785 ms, per step time: 1044.277 ms
 
 Check the checkpoint path used for evaluation before running the following command.
 
 ```shell
-  bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]
+  bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH] [DEVICE_ID]
 ```
 
 ### Evaluation Result
@@ -196,7 +215,7 @@ avgtime 0.19648232793807982
 
 bash run_infer_310.sh [The path of the MINDIR for 310 infer] [The path of the dataset for 310 infer] 0
 ```
 
-Note:: Before executing 310 infer, create the MINDIR/AIR model using "python export.py --ckpt-file [The path of the CKPT for exporting]".
+- Note: Before executing 310 infer, create the MINDIR/AIR model using "python export.py --ckpt-file [The path of the CKPT for exporting]"; a sketch of this export step follows.
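For reference, the export step can be sketched as follows. This is a minimal sketch rather than the repository's exact export.py: the `src.models.ICNet` import path and the 1x3x1024x2048 dummy-input shape are assumptions used only to illustrate the MindSpore export API.

```python
# Minimal export sketch (hypothetical names: the ICNet import path and the
# input shape are assumptions; the repository's export.py is authoritative).
import numpy as np
from mindspore import Tensor, context, export, load_checkpoint, load_param_into_net

from src.models import ICNet  # hypothetical import path

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = ICNet()
load_param_into_net(net, load_checkpoint("/path/to/ICNet.ckpt"))  # restore trained weights

# A dummy input fixes the graph input shape baked into the serialized model.
dummy_input = Tensor(np.zeros([1, 3, 1024, 2048], dtype=np.float32))
export(net, dummy_input, file_name="icnet", file_format="MINDIR")
```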
 # [Model Description](#Content)
 
@@ -204,7 +223,7 @@
 
 ### Training Performance
 
-|Parameter            | MaskRCNN                                                   |
+|Parameter            | ICNet                                                      |
 | ------------------- | ---------------------------------------------------------- |
 |resources            | Ascend 910; CPU 2.60 GHz, 192 cores; memory: 755 GB        |
 |Upload date          | 2021.6.1                                                   |
diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/scripts/run_distribute_train.sh b/model_zoo/research/cv/ICNet/Res50V1_PRE/scripts/run_distribute_train.sh
new file mode 100644
index 00000000000..62b9cbf4db1
--- /dev/null
+++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/scripts/run_distribute_train.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then
+    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $1)
+PATH2=$(get_real_path $2)
+
+
+if [ ! -f $PATH1 ]
+then
+    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
+    exit 1
+fi
+
+if [ ! -d $PATH2 ]
+then
+    echo "error: DATASET_PATH=$PATH2 is not a directory"
+    exit 1
+fi
+
+
+export SERVER_ID=0
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+rank_start=$((DEVICE_NUM * SERVER_ID))
+first_device=0
+export RANK_TABLE_FILE=$PATH1
+
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$((first_device+i))
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp ../*.py ./train_parallel$i
+    cp *.sh ./train_parallel$i
+    cp -r ../src ./train_parallel$i
+    cd ./train_parallel$i || exit
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    env > env.log
+    python train.py --data_url=$PATH2 &> log &
+
+    cd ..
+done
diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/src/CrossEntropySmooth.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/CrossEntropySmooth.py
new file mode 100644
index 00000000000..24bb6995ed8
--- /dev/null
+++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/CrossEntropySmooth.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""define loss function for network""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +class CrossEntropySmooth(_Loss): + """CrossEntropy""" + def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000): + super(CrossEntropySmooth, self).__init__() + self.onehot = P.OneHot() + self.sparse = sparse + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / num_classes, mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) + + def construct(self, logit, label): + if self.sparse: + label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, label) + return loss diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/src/config.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/config.py new file mode 100644 index 00000000000..44e5dfb6d68 --- /dev/null +++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/config.py @@ -0,0 +1,42 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed +# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional. +cfg = ed({ + 'optimizer': 'Momentum', + }) + +config1 = ed({ + "class_num": 1000, + "batch_size": 256, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 150, + "pretrain_epoch_size": 0, + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 5, + "decay_mode": "linear", + "save_checkpoint_path": "./checkpoints", + "hold_epochs": 0, + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr_init": 0.8, + "lr_end": 0.0 +}) diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/src/dataset.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/dataset.py new file mode 100644 index 00000000000..3f032c27b04 --- /dev/null +++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/dataset.py @@ -0,0 +1,108 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+create train or eval dataset.
+"""
+import os
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset.transforms.c_transforms as C2
+from mindspore.communication.management import init, get_rank, get_group_size
+
+
+def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32,
+                   target="Ascend", distribute=False):
+    """
+    create a train or eval imagenet2012 dataset for resnet50
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: Ascend
+        distribute(bool): data for distribute or not. Default: False
+
+    Returns:
+        dataset
+    """
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+    else:
+        if distribute:
+            init()
+            rank_id = get_rank()
+            device_num = get_group_size()
+        else:
+            device_num = 1
+
+    if device_num == 1:
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
+    else:
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
+                                         num_shards=device_num, shard_id=rank_id)
+
+    image_size = 224
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.Normalize(mean=mean, std=std),
+            C.HWC2CHW()
+        ]
+    else:
+        trans = [
+            C.Decode(),
+            C.Resize(256),
+            C.CenterCrop(image_size),
+            C.Normalize(mean=mean, std=std),
+            C.HWC2CHW()
+        ]
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    # default to a single device when RANK_SIZE is not exported by the launcher
+    rank_size = int(os.getenv("RANK_SIZE", "1"))
+
+    if rank_size > 1:
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = 1
+        rank_id = 0
+
+    return rank_size, rank_id
diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/src/lr_generator.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/lr_generator.py
new file mode 100644
index 00000000000..5806b43c01e
--- /dev/null
+++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/lr_generator.py
@@ -0,0 +1,84 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+import math
+import numpy as np
+
+
+def _generate_linear_lr(lr_init, lr_end, total_steps):
+    """
+    Applies linear decay to generate learning rate array.
+
+    Args:
+        lr_init(float): init learning rate.
+        lr_end(float): end learning rate.
+        total_steps(int): all steps in training.
+
+    Returns:
+        list, learning rate for each step.
+    """
+    lr_each_step = []
+    for i in range(total_steps):
+        lr = lr_init - (lr_init - lr_end) * i / total_steps
+        lr_each_step.append(lr)
+
+    return lr_each_step
+
+
+def _generate_cosine_lr(lr_init, total_steps):
+    """
+    Applies cosine decay to generate learning rate array.
+
+    Args:
+        lr_init(float): init learning rate.
+        total_steps(int): all steps in training.
+
+    Returns:
+        list, learning rate for each step.
+    """
+    decay_steps = total_steps
+    lr_each_step = []
+    for i in range(total_steps):
+        linear_decay = (total_steps - i) / decay_steps
+        cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
+        decayed = linear_decay * cosine_decay + 0.00001
+        lr = lr_init * decayed
+        lr_each_step.append(lr)
+    return lr_each_step
+
+
+def get_lr(lr_init, lr_end, total_epochs, steps_per_epoch, decay_mode):
+    """
+    generate learning rate array
+
+    Args:
+        lr_init(float): init learning rate.
+        lr_end(float): end learning rate, only used by linear decay.
+        total_epochs(int): total epochs of training.
+        steps_per_epoch(int): steps of one epoch.
+        decay_mode(string): "cosine" for cosine decay, anything else for linear decay.
+
+    Returns:
+        np.array, learning rate array.
+    """
+
+    total_steps = steps_per_epoch * total_epochs
+    if decay_mode == "cosine":
+        lr_each_step = _generate_cosine_lr(lr_init, total_steps)
+    else:
+        lr_each_step = _generate_linear_lr(lr_init, lr_end, total_steps)
+
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+
+    return lr_each_step
diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/src/resnet50_v1.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/resnet50_v1.py
new file mode 100644
index 00000000000..19d248be3ab
--- /dev/null
+++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/src/resnet50_v1.py
@@ -0,0 +1,288 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pretrained model resnet50"""
+import math
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+import mindspore as ms
+from mindspore import load_checkpoint, load_param_into_net
+
+
+def calculate_gain(nonlinearity, param=None):
+    """calculate_gain"""
+    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    res = 0
+    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+        res = 1
+    elif nonlinearity == 'tanh':
+        res = 5.0 / 3
+    elif nonlinearity == 'relu':
+        res = math.sqrt(2.0)
+    elif nonlinearity == 'leaky_relu':
+        if param is None:
+            negative_slope = 0.01
+        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
+            # True/False are instances of int, hence the explicit bool check above
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(param))
+        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
+    else:
+        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+    return res
+
+
+def _calculate_fan_in_and_fan_out(tensor):
+    """_calculate_fan_in_and_fan_out"""
+    dimensions = len(tensor)
+    if dimensions < 2:
+        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
+    if dimensions == 2:  # Linear
+        fan_in = tensor[1]
+        fan_out = tensor[0]
+    else:
+        num_input_fmaps = tensor[1]
+        num_output_fmaps = tensor[0]
+        receptive_field_size = 1
+        if dimensions > 2:
+            receptive_field_size = tensor[2] * tensor[3]
+        fan_in = num_input_fmaps * receptive_field_size
+        fan_out = num_output_fmaps * receptive_field_size
+    return fan_in, fan_out
+
+
+def _calculate_correct_fan(tensor, mode):
+    mode = mode.lower()
+    valid_modes = ['fan_in', 'fan_out']
+    if mode not in valid_modes:
+        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    return fan_in if mode == 'fan_in' else fan_out
+
+
+def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+    fan = _calculate_correct_fan(inputs_shape, mode)
+    gain = calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
+
+
+def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
+    fan = _calculate_correct_fan(inputs_shape, mode)
+    gain = calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
+    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
+
+
+def _conv3x3(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 3, 3)
+    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
+
+    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
+                     padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv1x1(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 1, 1)
+    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
+
+    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
+                     padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv7x7(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 7, 7)
+    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _bn(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
+                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _bn_last(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
+                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _fc(in_channel, out_channel):
+    weight_shape = (out_channel, in_channel)
+    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
+    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
+
+
+class BottleneckV1b(nn.Cell):
+    """BottleneckV1b"""
+    def __init__(self, in_channel, out_channel, stride, dilation=1):
+        super().__init__()
+        expansion = 4
+
+        # middle channel num
+        channel = out_channel // expansion
+        self.conv1 = nn.Conv2dBnAct(in_channel, channel, kernel_size=1, stride=1,
+                                    has_bn=True, pad_mode="same", activation='relu')
+
+        self.conv2 = nn.Conv2dBnAct(channel, channel, kernel_size=3, stride=stride,
+                                    dilation=dilation, has_bn=True, pad_mode="same", activation='relu')
+
+        self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same',
+                                    has_bn=True)
+
+        # whether down-sample identity
+        self.down_sample = False
+        if stride != 1 or in_channel != out_channel:
+            self.down_sample = True
+
+        self.down_layer = None
+        if self.down_sample:
+            self.down_layer = nn.Conv2dBnAct(in_channel, out_channel,
+                                             kernel_size=1, stride=stride,
+                                             pad_mode='same', has_bn=True)
+        self.relu = nn.ReLU()
+        self.add = ms.ops.Add()
+
+    def construct(self, x):
+        """construct"""
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.down_sample:
+            identity = self.down_layer(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class Resnet50v1b(nn.Cell):
+    """Resnet50v1b"""
+
+    def __init__(self,
+                 block,
+                 layer_nums,
+                 in_channels,
+                 out_channels,
+                 strides,
+                 num_classes):
+        super(Resnet50v1b, self).__init__()
+
+        # initial stage
+        self.conv1 = _conv7x7(3, 64, stride=2)
+        self.bn1 = _bn(64)
+        self.relu = P.ReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+
+        self.layer1 = self._make_layer(block=block,
+                                       layer_num=layer_nums[0],
+                                       in_channel=in_channels[0],
+                                       out_channel=out_channels[0],
+                                       stride=strides[0],
+                                       dilation=1)
+        self.layer2 = self._make_layer(block=block,
+                                       layer_num=layer_nums[1],
+                                       in_channel=in_channels[1],
+                                       out_channel=out_channels[1],
+                                       stride=strides[1],
+                                       dilation=1)
+        self.layer3 = self._make_layer(block=block,
+                                       layer_num=layer_nums[2],
+                                       in_channel=in_channels[2],
+                                       out_channel=out_channels[2],
+                                       stride=strides[2],
+                                       dilation=2)
+        self.layer4 = self._make_layer(block=block,
+                                       layer_num=layer_nums[3],
+                                       in_channel=in_channels[3],
+                                       out_channel=out_channels[3],
+                                       stride=strides[3],
+                                       dilation=4)
+
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.flatten = nn.Flatten()
+        self.end_point = _fc(out_channels[3], num_classes)
+
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, dilation):
+        """make layers"""
+        layers = []
+
+        resblock = block(in_channel=in_channel,
+                         out_channel=out_channel,
+                         stride=stride,
+                         dilation=dilation)
+        layers.append(resblock)
+        for _ in range(1, layer_num):
+            resblock = block(in_channel=out_channel,
+                             out_channel=out_channel,
+                             stride=1,
+                             dilation=dilation)
+            layers.append(resblock)
+        return nn.SequentialCell(layers)
+
+    def construct(self, x):
+        """construct"""
+        # initial stage
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        c1 = self.maxpool(x)
+
+        # four groups
+        c2 = self.layer1(c1)
+        c3 = self.layer2(c2)
+        c4 = self.layer3(c3)
+        c5 = self.layer4(c4)
+
+        out = self.mean(c5, (2, 3))
+        out = self.flatten(out)
+        out = self.end_point(out)
+
+        return out
+
+
+def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True):
+    """
+    Get ResNet50-v1b neural network, the backbone pre-trained for ICNet.
+
+    Args:
+        class_num (int): Class number.
+        ckpt_root (str): Path of a pretrained checkpoint to load.
+        pretrained (bool): Whether to load the checkpoint at ckpt_root.
+
+    Returns:
+        Cell, cell instance of ResNet50-v1b neural network.
+
+    Examples:
+        >>> net = get_resnet50v1b(1001)
+    """
+
+    model = Resnet50v1b(block=BottleneckV1b,
+                        layer_nums=[3, 4, 6, 3],
+                        in_channels=[64, 256, 512, 1024],
+                        out_channels=[256, 512, 1024, 2048],
+                        strides=[1, 2, 2, 2],
+                        num_classes=class_num)
+
+    if pretrained:
+        pretrained_ckpt = ckpt_root
+        param_dict = load_checkpoint(pretrained_ckpt)
+        load_param_into_net(model, param_dict)
+        print("pretrained checkpoint loaded")
+
+    return model
diff --git a/model_zoo/research/cv/ICNet/Res50V1_PRE/train.py b/model_zoo/research/cv/ICNet/Res50V1_PRE/train.py
new file mode 100644
index 00000000000..29a6e801133
--- /dev/null
+++ b/model_zoo/research/cv/ICNet/Res50V1_PRE/train.py
@@ -0,0 +1,189 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train ResNet50-v1b, the pre-trained backbone for ICNet."""
+import os
+import argparse
+from mindspore import context
+from mindspore import Tensor
+from mindspore.nn.optim import Momentum
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import LossMonitor, TimeMonitor
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import init
+from mindspore.common import set_seed
+from mindspore.parallel import set_algo_parameters
+import mindspore.nn as nn
+import mindspore.common.initializer as weight_init
+from src.CrossEntropySmooth import CrossEntropySmooth
+from src.resnet50_v1 import get_resnet50v1b as net
+from src.lr_generator import get_lr
+from src.dataset import create_dataset
+from src.config import config1 as config
+
+parser = argparse.ArgumentParser(description='Image classification')
+
+parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
+parser.add_argument('--train_url', type=str, default=None, help='Train output path')
+parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
+                    help="Device target, support Ascend, GPU and CPU.")
+parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
+parser.add_argument('--is_modelarts', type=str, default="False", help='is train on modelarts')
+args_opt = parser.parse_args()
+
+if args_opt.is_modelarts == "True":
+    import moxing as mox
+
+set_seed(1)
+
+
+def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
+    """remove useless parameters according to filter_list"""
+    for key in list(origin_dict.keys()):
+        for name in param_filter:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del origin_dict[key]
+                break
+
+
+if __name__ == '__main__':
+
+    device_id = int(os.getenv('DEVICE_ID'))
+    device_num = int(os.getenv("RANK_SIZE"))
+
+    ckpt_save_dir = config.save_checkpoint_path
+    local_train_data_url = args_opt.data_url
+
+    if args_opt.is_modelarts == "True":
+        local_summary_dir = "/cache/summary"
+        local_data_url = "/cache/data"
+        local_train_url = "/cache/ckpt"
+        local_zipfolder_url = "/cache/tarzip"
+        ckpt_save_dir = local_train_url
+        mox.file.make_dirs(local_train_url)
+        mox.file.make_dirs(local_summary_dir)
+        filename = "imagenet_original.tar.gz"
+        # transfer dataset
+        local_data_url = os.path.join(local_data_url, str(device_id))
+        mox.file.make_dirs(local_data_url)
+        local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
+        obs_zip_path = os.path.join(args_opt.data_url, filename)
+        mox.file.copy(obs_zip_path, local_zip_path)
+        unzip_command = "tar -xvf %s -C %s" % (local_zip_path, local_data_url)
+        os.system(unzip_command)
+        local_train_data_url = os.path.join(local_data_url, "imagenet_original", "train")
+
+    target = args_opt.device_target
+    if target != 'Ascend':
+        raise ValueError("Unsupported device target.")
+
+    run_distribute = False
+
+    if device_num > 1:
+        run_distribute = True
+
+    # init context
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+
+    if run_distribute:
+
+        context.set_context(device_id=device_id,
+                            enable_auto_mixed_precision=True)
+        context.set_auto_parallel_context(device_num=device_num,
+                                          parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True)
+        set_algo_parameters(elementwise_op_strategy_follow=True)
+        context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
+        init()
+
+    # create dataset
+    dataset = create_dataset(dataset_path=local_train_data_url, do_train=True, repeat_num=1,
+                             batch_size=config.batch_size, target=target, distribute=run_distribute)
+    step_size = dataset.get_dataset_size()
+
+    # define net
+
+    net = net(class_num=config.class_num, pretrained=False)
+
+    # init weight
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+
+        load_param_into_net(net, param_dict)
+    else:
+        for _, cell in net.cells_and_names():
+            if isinstance(cell, nn.Conv2d):
+                cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+
+    lr = get_lr(config.lr_init, config.lr_end, config.epoch_size, step_size, config.decay_mode)
+
+    lr = Tensor(lr)
+
+    # define opt
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
+    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
+                    {'params': no_decayed_params},
+                    {'order_params': net.trainable_params()}]
+
+    opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+    # define loss, model
+    if target == "Ascend":
+        if not config.use_label_smooth:
+            config.label_smooth_factor = 0.0
+
+        loss = CrossEntropySmooth(sparse=True, reduction="mean",
+                                  smooth_factor=config.label_smooth_factor,
+                                  num_classes=config.class_num)
+
+        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
+                      metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
+    else:
+        raise ValueError("Unsupported device target.")
+
+    # define callbacks
+    time_cb = TimeMonitor(data_size=step_size)
+    loss_cb = LossMonitor()
+    rank_id = int(os.getenv("RANK_ID"))
+
+    cb = [time_cb, loss_cb]
+
+    if rank_id == 0:
+        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
+                                     keep_checkpoint_max=config.keep_checkpoint_max)
+        ckpt_cb = ModelCheckpoint(prefix="ResNet50V1B", directory=ckpt_save_dir, config=config_ck)
+        cb += [ckpt_cb]
+
+    dataset_sink_mode = target != "CPU"
+    model.train(config.epoch_size, dataset, callbacks=cb,
+                sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)
+
+    if device_id == 0 and args_opt.is_modelarts == "True":
+        mox.file.copy_parallel(ckpt_save_dir, args_opt.train_url)
diff --git a/model_zoo/research/cv/ICNet/eval.py b/model_zoo/research/cv/ICNet/eval.py
index bccbb3ed434..e2ab20fac6e 100644
--- a/model_zoo/research/cv/ICNet/eval.py
+++ b/model_zoo/research/cv/ICNet/eval.py
@@ -74,7 +74,6 @@ class Evaluator:
             mask = self._mask_transform(mask)  # mask shape: (H,w)
 
             image = Tensor(image)
-            print(image)
 
             expand_dims = ops.ExpandDims()
             image = expand_dims(image, 0)
@@ -84,8 +83,8 @@ class Evaluator:
             end_time = time.time()
             step_time = end_time - start_time
 
-            expand_dims = ops.ExpandDims()
-            mask = expand_dims(mask, 0)
+            output = np.array(output)
+            mask = np.expand_dims(mask, axis=0)
 
             self.metric.update(output, mask)
             list_time.append(step_time)
 
diff --git a/model_zoo/research/cv/ICNet/scripts/run_eval.sh b/model_zoo/research/cv/ICNet/scripts/run_eval.sh
index 74495640f9a..396d49719d2 100644
--- a/model_zoo/research/cv/ICNet/scripts/run_eval.sh
+++ b/model_zoo/research/cv/ICNet/scripts/run_eval.sh
@@ -14,9 +14,9 @@
 # limitations under the License.
 # ============================================================================
 
-if [ $# != 3 ]
+if [ $# != 4 ]
 then
-    echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH]"
+    echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PROJECT_PATH] [DEVICE_ID]"
     exit 1
 fi
 
@@ -53,7 +53,7 @@ fi
 ulimit -u unlimited
 export DEVICE_NUM=1
-export DEVICE_ID=0
+export DEVICE_ID=$4
 export RANK_SIZE=1
 export RANK_ID=0
 
@@ -68,6 +68,6 @@ cp -r ../src ./eval
 cd ./eval || exit
 env > env.log
 echo "start evaluation for device $DEVICE_ID"
-python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --project_path=$PATH3 &> log &
+python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --project_path=$PATH3 --device=$4 &> log &
 
 cd ..
diff --git a/model_zoo/research/cv/ICNet/src/model_utils/icnet.yaml b/model_zoo/research/cv/ICNet/src/model_utils/icnet.yaml
index 9fc8d38a8a5..649ff114b8d 100644
--- a/model_zoo/research/cv/ICNet/src/model_utils/icnet.yaml
+++ b/model_zoo/research/cv/ICNet/src/model_utils/icnet.yaml
@@ -1,7 +1,7 @@
 ### 1.Model
 model:
     name: "icnet"
-    backbone: "resnet50"
+    backbone: "resnet50v1"
     base_size: 1024    # during augmentation, shorter size will be resized between [base_size*0.5, base_size*2.0]
     crop_size: 960     # end of augmentation, crop to training
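Because run_eval.sh now forwards a fourth positional argument as `--device=$4`, eval.py needs a matching parser entry. A minimal sketch of what that entry could look like, assuming eval.py builds a standard argparse parser (the argument names here are illustrative and must match whatever eval.py actually parses):

```python
# Hypothetical sketch of the parser entry needed for --device=$4 from
# run_eval.sh; the other arguments mirror those the script already passes.
import argparse

from mindspore import context

parser = argparse.ArgumentParser(description='ICNet evaluation')
parser.add_argument('--dataset_path', type=str, required=True, help='Cityscapes root directory')
parser.add_argument('--checkpoint_path', type=str, required=True, help='trained ICNet checkpoint')
parser.add_argument('--project_path', type=str, required=True, help='ICNet project root')
parser.add_argument('--device', type=int, default=0, help='Ascend device id')
args = parser.parse_args()

# Bind this evaluation process to the requested Ascend card.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device)
```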